Skip to main content

sgr_agent/openapi/
mod.rs

//! OpenAPI → Agent Tool: convert any API spec into a searchable, callable tool.
//!
//! Instead of exposing 845 individual MCP tools for the GitHub API, expose a single `api` tool:
//! - `api search "create issue"` → fuzzy-find endpoints
//! - `api call repos_owner_repo_issues_post --owner=foo --repo=bar --title="bug"` → execute
//!
//! ## Usage
//!
//! ```ignore
//! use sgr_agent::openapi::{ApiRegistry, ApiAuth};
//!
//! let mut registry = ApiRegistry::new();
//! registry.add_api("github", "https://api.github.com", &spec_json, ApiAuth::Bearer("ghp_xxx".into())).unwrap();
//! let results = registry.search("github", "create issue", 5);
//! ```
16
17pub mod caller;
18pub mod registry;
19pub mod search;
20pub mod spec;
21
22pub use caller::ApiAuth;
23pub use registry::{
24    default_cache_dir, download_spec, find_popular, list_popular, load_api_registry,
25    load_or_download, popular_apis, search_apis_guru, ApiSpec,
26};
27pub use search::{format_results, search_endpoints, SearchResult};
28pub use spec::{filter_endpoints, parse_spec, Endpoint, Param, ParamLocation};
29
30use std::collections::HashMap;
31
32/// Registry of loaded API specs. Each API has a name, base URL, and parsed endpoints.
33#[derive(Default)]
34pub struct ApiRegistry {
35    apis: HashMap<String, LoadedApi>,
36}
37
38struct LoadedApi {
39    base_url: String,
40    endpoints: Vec<Endpoint>,
41    auth: ApiAuth,
42}
43
44impl ApiRegistry {
45    pub fn new() -> Self {
46        Self::default()
47    }
48
49    /// Add an API from a JSON spec string.
50    pub fn add_api(
51        &mut self,
52        name: &str,
53        base_url: &str,
54        spec_json: &str,
55        auth: ApiAuth,
56    ) -> Result<usize, String> {
57        let spec: serde_json::Value =
58            serde_json::from_str(spec_json).map_err(|e| format!("Invalid JSON: {}", e))?;
59        let endpoints = parse_spec(&spec);
60        let count = endpoints.len();
61        self.apis.insert(
62            name.to_string(),
63            LoadedApi {
64                base_url: base_url.to_string(),
65                endpoints,
66                auth,
67            },
68        );
69        Ok(count)
70    }
71
72    /// Add an API from a pre-parsed spec Value.
73    pub fn add_api_from_value(
74        &mut self,
75        name: &str,
76        base_url: &str,
77        spec: &serde_json::Value,
78        auth: ApiAuth,
79    ) -> usize {
80        let endpoints = parse_spec(spec);
81        let count = endpoints.len();
82        self.apis.insert(
83            name.to_string(),
84            LoadedApi {
85                base_url: base_url.to_string(),
86                endpoints,
87                auth,
88            },
89        );
90        count
91    }
92
93    /// Load a popular API by name — auto-download + cache.
94    /// Returns endpoint count or error.
95    pub async fn load_popular(&mut self, name: &str) -> Result<usize, String> {
96        let api_spec = find_popular(name)
97            .ok_or_else(|| format!("Unknown API: {}. Available: {:?}", name, list_popular()))?;
98        self.load_spec(&api_spec).await
99    }
100
101    /// Load any ApiSpec — download (or use cache) and register.
102    pub async fn load_spec(&mut self, api_spec: &ApiSpec) -> Result<usize, String> {
103        let cache_dir = default_cache_dir();
104        let json = load_or_download(&cache_dir, &api_spec.name, &api_spec.spec_url).await?;
105
106        // Auto-detect auth from env var
107        let auth = if let Some(ref env_var) = api_spec.auth_env {
108            match std::env::var(env_var) {
109                Ok(token) if !token.is_empty() => ApiAuth::Bearer(token),
110                _ => ApiAuth::None,
111            }
112        } else {
113            ApiAuth::None
114        };
115
116        self.add_api(&api_spec.name, &api_spec.base_url, &json, auth)
117    }
118
119    /// List all loaded API names.
120    pub fn list_apis(&self) -> Vec<&str> {
121        self.apis.keys().map(|s| s.as_str()).collect()
122    }
123
124    /// Get endpoint count for an API.
125    pub fn endpoint_count(&self, api_name: &str) -> usize {
126        self.apis
127            .get(api_name)
128            .map(|a| a.endpoints.len())
129            .unwrap_or(0)
130    }
131
132    /// Search endpoints within a specific API.
133    pub fn search(&self, api_name: &str, query: &str, limit: usize) -> Vec<SearchResult> {
134        match self.apis.get(api_name) {
135            Some(api) => search_endpoints(&api.endpoints, query, limit),
136            None => Vec::new(),
137        }
138    }
139
140    /// Search across ALL loaded APIs.
141    pub fn search_all(&self, query: &str, limit: usize) -> Vec<(String, SearchResult)> {
142        let mut all: Vec<(String, SearchResult)> = Vec::new();
143        for (name, api) in &self.apis {
144            for r in search_endpoints(&api.endpoints, query, limit) {
145                all.push((name.clone(), r));
146            }
147        }
148        all.sort_by(|a, b| b.1.score.cmp(&a.1.score));
149        all.truncate(limit);
150        all
151    }
152
153    /// Find an endpoint by name within an API.
154    pub fn find_endpoint(&self, api_name: &str, endpoint_name: &str) -> Option<&Endpoint> {
155        self.apis
156            .get(api_name)?
157            .endpoints
158            .iter()
159            .find(|e| e.name == endpoint_name)
160    }
161
162    /// Call an endpoint by name.
163    pub async fn call(
164        &self,
165        api_name: &str,
166        endpoint_name: &str,
167        params: &HashMap<String, String>,
168        body: Option<&serde_json::Value>,
169    ) -> Result<String, String> {
170        let api = self
171            .apis
172            .get(api_name)
173            .ok_or_else(|| format!("API not found: {}", api_name))?;
174        let endpoint = api
175            .endpoints
176            .iter()
177            .find(|e| e.name == endpoint_name)
178            .ok_or_else(|| format!("Endpoint not found: {}", endpoint_name))?;
179
180        caller::call_api(&api.base_url, endpoint, params, body, &api.auth).await
181    }
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Minimal GitHub-like spec with four operations across three paths.
    fn github_spec() -> String {
        json!({
            "paths": {
                "/repos/{owner}/{repo}": {
                    "get": {
                        "summary": "Get a repository",
                        "parameters": [
                            { "name": "owner", "in": "path", "required": true, "schema": { "type": "string" } },
                            { "name": "repo", "in": "path", "required": true, "schema": { "type": "string" } }
                        ]
                    }
                },
                "/repos/{owner}/{repo}/issues": {
                    "get": {
                        "summary": "List issues",
                        "parameters": [
                            { "name": "owner", "in": "path", "required": true, "schema": { "type": "string" } },
                            { "name": "repo", "in": "path", "required": true, "schema": { "type": "string" } },
                            { "name": "state", "in": "query", "schema": { "type": "string" } }
                        ]
                    },
                    "post": {
                        "summary": "Create an issue",
                        "parameters": [
                            { "name": "owner", "in": "path", "required": true, "schema": { "type": "string" } },
                            { "name": "repo", "in": "path", "required": true, "schema": { "type": "string" } }
                        ]
                    }
                },
                "/users": {
                    "get": { "summary": "List users", "parameters": [] }
                }
            }
        })
        .to_string()
    }

    /// Single-endpoint spec used to populate a second API.
    fn items_spec() -> String {
        json!({"paths": {"/items": {"get": {"summary": "List items"}}}}).to_string()
    }

    /// Build a registry with the GitHub fixture registered under `name`.
    fn registry_with_github(name: &str) -> ApiRegistry {
        let mut reg = ApiRegistry::new();
        reg.add_api(name, "https://api.github.com", &github_spec(), ApiAuth::None)
            .expect("fixture spec is valid JSON");
        reg
    }

    #[test]
    fn add_api_and_count() {
        let mut reg = ApiRegistry::new();
        let count = reg
            .add_api(
                "github",
                "https://api.github.com",
                &github_spec(),
                ApiAuth::None,
            )
            .unwrap();
        assert_eq!(count, 4);
        assert_eq!(reg.endpoint_count("github"), 4);
    }

    #[test]
    fn add_api_from_value_and_count() {
        let spec: serde_json::Value =
            serde_json::from_str(&github_spec()).expect("fixture spec is valid JSON");
        let mut reg = ApiRegistry::new();
        let count =
            reg.add_api_from_value("gh", "https://api.github.com", &spec, ApiAuth::None);
        assert_eq!(count, 4);
        assert_eq!(reg.endpoint_count("gh"), 4);
    }

    #[test]
    fn re_adding_same_name_replaces_api() {
        let mut reg = registry_with_github("gh");
        let count = reg
            .add_api("gh", "https://other.com", &items_spec(), ApiAuth::None)
            .unwrap();
        assert_eq!(count, 1);
        assert_eq!(reg.endpoint_count("gh"), 1);
        assert_eq!(reg.list_apis(), vec!["gh"]);
    }

    #[test]
    fn endpoint_count_unknown_api_is_zero() {
        assert_eq!(ApiRegistry::new().endpoint_count("nope"), 0);
    }

    #[test]
    fn list_apis() {
        let reg = registry_with_github("github");
        assert_eq!(reg.list_apis(), vec!["github"]);
    }

    #[test]
    fn find_endpoint_by_name() {
        let reg = registry_with_github("gh");
        let ep = reg
            .find_endpoint("gh", "repos_owner_repo_issues_post")
            .expect("POST issues endpoint should be registered");
        assert_eq!(ep.method, "POST");
    }

    #[test]
    fn find_nonexistent_endpoint() {
        let reg = registry_with_github("gh");
        assert!(reg.find_endpoint("gh", "nonexistent").is_none());
        assert!(reg.find_endpoint("nope", "anything").is_none());
    }

    #[test]
    fn search_within_api() {
        let reg = registry_with_github("gh");
        let results = reg.search("gh", "issue", 5);
        assert!(!results.is_empty());
    }

    #[test]
    fn search_nonexistent_api() {
        let reg = ApiRegistry::new();
        let results = reg.search("nope", "test", 5);
        assert!(results.is_empty());
    }

    #[test]
    fn invalid_json_returns_error() {
        let mut reg = ApiRegistry::new();
        let err = reg
            .add_api("bad", "https://example.com", "not json", ApiAuth::None)
            .unwrap_err();
        assert!(err.contains("Invalid JSON"));
    }

    #[test]
    fn search_all_across_apis() {
        let mut reg = registry_with_github("gh");
        reg.add_api("other", "https://other.com", &items_spec(), ApiAuth::None)
            .unwrap();
        let results = reg.search_all("list", 10);
        // Candidates: "List issues" and "List users" (gh), "List items" (other).
        assert!(results.len() >= 2);
        assert!(results.iter().any(|(api, _)| api == "gh"));
        assert!(results.iter().any(|(api, _)| api == "other"));
        // Merged results must be ordered by descending score.
        assert!(results.windows(2).all(|w| w[0].1.score >= w[1].1.score));
    }
}