1mod extract;
24mod resolve;
25mod tool;
26
27use std::collections::HashMap;
28use std::path::Path;
29
30use openapiv3::{OpenAPI, ReferenceOr};
31use rig::tool::ToolDyn;
32
33use crate::extract::{extract_body_schema, extract_param_info};
34use crate::resolve::Resolver;
35use crate::tool::{HttpMethod, OpenApiTool};
36
37pub struct OpenApiToolset {
46 tools: Vec<OpenApiTool>,
47}
48
49pub struct OpenApiToolsetBuilder {
51 spec_str: String,
52 base_url: Option<String>,
53 client: Option<reqwest::Client>,
54 hidden_context: HashMap<String, String>,
55}
56
57impl OpenApiToolsetBuilder {
58 pub fn base_url(mut self, url: impl Into<String>) -> Self {
60 self.base_url = Some(url.into());
61 self
62 }
63
64 pub fn client(mut self, client: reqwest::Client) -> Self {
66 self.client = Some(client);
67 self
68 }
69
70 pub fn hidden_context(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
73 self.hidden_context.insert(key.into(), value.into());
74 self
75 }
76
77 pub fn bearer_token(self, token: &str) -> Self {
79 use reqwest::header;
80 let mut headers = header::HeaderMap::new();
81 let mut auth_value =
82 header::HeaderValue::from_str(&format!("Bearer {token}")).expect("invalid token");
83 auth_value.set_sensitive(true);
84 headers.insert(header::AUTHORIZATION, auth_value);
85
86 let client = reqwest::Client::builder()
87 .default_headers(headers)
88 .build()
89 .expect("failed to build reqwest client");
90 self.client(client)
91 }
92
93 pub fn build(self) -> anyhow::Result<OpenApiToolset> {
95 OpenApiToolset::build_inner(
96 &self.spec_str,
97 self.base_url.as_deref(),
98 self.client,
99 self.hidden_context,
100 )
101 }
102}
103
104impl OpenApiToolset {
105 pub fn from_file(path: impl AsRef<Path>) -> anyhow::Result<Self> {
107 let content = std::fs::read_to_string(path)?;
108 Self::from_spec_str(&content)
109 }
110
111 pub fn from_spec_str(spec_str: &str) -> anyhow::Result<Self> {
113 Self::build_inner(spec_str, None, None, HashMap::new())
114 }
115
116 pub fn builder(spec_str: &str) -> OpenApiToolsetBuilder {
118 OpenApiToolsetBuilder {
119 spec_str: spec_str.to_string(),
120 base_url: None,
121 client: None,
122 hidden_context: HashMap::new(),
123 }
124 }
125
126 pub fn builder_from_file(path: impl AsRef<Path>) -> anyhow::Result<OpenApiToolsetBuilder> {
128 let content = std::fs::read_to_string(path)?;
129 Ok(OpenApiToolsetBuilder {
130 spec_str: content,
131 base_url: None,
132 client: None,
133 hidden_context: HashMap::new(),
134 })
135 }
136
137 fn build_inner(
138 spec_str: &str,
139 base_url_override: Option<&str>,
140 client: Option<reqwest::Client>,
141 hidden_context: HashMap<String, String>,
142 ) -> anyhow::Result<Self> {
143 let spec: OpenAPI = serde_yaml::from_str(spec_str)?;
144 let resolver = Resolver::new(&spec);
145
146 let base_url = base_url_override
147 .map(|s| s.to_string())
148 .or_else(|| spec.servers.first().map(|s| s.url.clone()))
149 .unwrap_or_else(|| "http://localhost".into());
150 let base_url = base_url.trim_end_matches('/').to_string();
151
152 let client = client.unwrap_or_default();
153 let mut tools: Vec<OpenApiTool> = Vec::new();
154
155 for (path_template, path_item_ref) in &spec.paths {
156 let ReferenceOr::Item(path_item) = path_item_ref else {
157 continue;
158 };
159
160 let methods = [
161 (HttpMethod::Get, &path_item.get),
162 (HttpMethod::Post, &path_item.post),
163 (HttpMethod::Put, &path_item.put),
164 (HttpMethod::Patch, &path_item.patch),
165 (HttpMethod::Delete, &path_item.delete),
166 ];
167
168 for (method, op) in methods {
169 let Some(op) = op else { continue };
170
171 let method_lower = method.as_str().to_lowercase();
172 let operation_id = op.operation_id.clone().unwrap_or_else(|| {
173 let path_slug = path_template.replace('/', "_");
174 let path_slug = path_slug.trim_start_matches('_');
175 format!("{}_{}", method_lower, path_slug)
176 });
177
178 let description = op
179 .summary
180 .clone()
181 .or_else(|| op.description.clone())
182 .unwrap_or_else(|| format!("{} {}", method.as_str(), path_template));
183
184 let parameters = op
185 .parameters
186 .iter()
187 .filter_map(|p| {
188 let param = resolver.resolve_parameter(p)?;
189 extract_param_info(param, &resolver)
190 })
191 .collect();
192
193 let (request_body_schema, request_body_required) = op
194 .request_body
195 .as_ref()
196 .and_then(|rb| resolver.resolve_request_body(rb))
197 .map(|body| extract_body_schema(body, &resolver))
198 .unwrap_or((None, false));
199
200 tools.push(OpenApiTool {
201 client: client.clone(),
202 base_url: base_url.clone(),
203 method,
204 path_template: path_template.clone(),
205 operation_id,
206 description,
207 parameters,
208 request_body_schema,
209 request_body_required,
210 hidden_params: hidden_context.clone(),
211 });
212 }
213 }
214
215 Ok(Self { tools })
216 }
217
218 pub fn len(&self) -> usize {
220 self.tools.len()
221 }
222
223 pub fn is_empty(&self) -> bool {
225 self.tools.is_empty()
226 }
227
228 pub fn into_tools(self) -> Vec<Box<dyn ToolDyn>> {
230 self.tools
231 .into_iter()
232 .map(|t| Box::new(t) as Box<dyn ToolDyn>)
233 .collect()
234 }
235
236 pub fn tools_with_context(&self, context: &HashMap<String, String>) -> Vec<Box<dyn ToolDyn>> {
243 self.tools
244 .iter()
245 .map(|t| {
246 let mut tool = t.clone();
247 tool.hidden_params.extend(context.clone());
248 Box::new(tool) as Box<dyn ToolDyn>
249 })
250 .collect()
251 }
252
253 pub fn context_preamble(context: &HashMap<String, String>) -> String {
257 if context.is_empty() {
258 return String::new();
259 }
260 let entries: Vec<String> = context
261 .iter()
262 .map(|(k, v)| format!("- {k} = {v}"))
263 .collect();
264 format!(
265 "The following context is available. Use these values when calling tools:\n{}",
266 entries.join("\n")
267 )
268 }
269}
270
#[cfg(test)]
mod tests {
    //! Unit tests for spec parsing, builder options, generated tool
    //! definitions and hidden-context handling.

    use super::*;
    use serde_json::Value;

    /// One path with a single GET operation and a required path parameter.
    const MINIMAL_SPEC: &str = r#"
openapi: "3.0.0"
info:
  title: Test
  version: "1.0"
servers:
  - url: https://api.example.com
paths:
  /users/{id}:
    get:
      operationId: getUser
      summary: Get a user by id
      parameters:
        - name: id
          in: path
          required: true
          schema:
            type: string
          description: The user id
      responses:
        "200":
          description: OK
"#;

    /// Two paths with four operations: query parameter, request body and
    /// path parameters.
    const MULTI_METHOD_SPEC: &str = r#"
openapi: "3.0.0"
info:
  title: Test
  version: "1.0"
servers:
  - url: https://api.example.com
paths:
  /users:
    get:
      operationId: listUsers
      summary: List all users
      parameters:
        - name: limit
          in: query
          required: false
          schema:
            type: integer
          description: Max results
      responses:
        "200":
          description: OK
    post:
      operationId: createUser
      summary: Create a user
      requestBody:
        required: true
        content:
          application/json:
            schema:
              type: object
              properties:
                name:
                  type: string
                email:
                  type: string
              required:
                - name
      responses:
        "201":
          description: Created
  /users/{id}:
    get:
      operationId: getUser
      summary: Get a user
      parameters:
        - name: id
          in: path
          required: true
          schema:
            type: string
      responses:
        "200":
          description: OK
    delete:
      operationId: deleteUser
      summary: Delete a user
      parameters:
        - name: id
          in: path
          required: true
          schema:
            type: string
      responses:
        "204":
          description: Deleted
"#;

    /// A single operation whose parameter is a `$ref` into `components`.
    const REF_SPEC: &str = r#"
openapi: "3.0.0"
info:
  title: Test
  version: "1.0"
servers:
  - url: https://api.example.com
paths:
  /items/{id}:
    get:
      operationId: getItem
      summary: Get an item
      parameters:
        - $ref: '#/components/parameters/ItemId'
      responses:
        "200":
          description: OK
components:
  parameters:
    ItemId:
      name: id
      in: path
      required: true
      schema:
        type: string
      description: The item id
"#;

    #[test]
    fn parse_minimal_spec() {
        let toolset = OpenApiToolset::from_spec_str(MINIMAL_SPEC).unwrap();
        assert_eq!(toolset.len(), 1);
    }

    #[test]
    fn parse_multi_method_spec() {
        let toolset = OpenApiToolset::from_spec_str(MULTI_METHOD_SPEC).unwrap();
        assert_eq!(toolset.len(), 4);
    }

    #[test]
    fn tool_names_match_operation_ids() {
        let toolset = OpenApiToolset::from_spec_str(MULTI_METHOD_SPEC).unwrap();
        let tools = toolset.into_tools();
        let names: Vec<String> = tools.iter().map(|t| t.name()).collect();
        assert!(names.contains(&"listUsers".to_string()));
        assert!(names.contains(&"createUser".to_string()));
        assert!(names.contains(&"getUser".to_string()));
        assert!(names.contains(&"deleteUser".to_string()));
    }

    #[test]
    fn fallback_operation_id_when_missing() {
        let spec = r#"
openapi: "3.0.0"
info:
  title: Test
  version: "1.0"
paths:
  /health:
    get:
      summary: Health check
      responses:
        "200":
          description: OK
"#;
        let toolset = OpenApiToolset::from_spec_str(spec).unwrap();
        let tools = toolset.into_tools();
        assert_eq!(tools[0].name(), "get_health");
    }

    #[test]
    fn base_url_from_spec() {
        let toolset = OpenApiToolset::from_spec_str(MINIMAL_SPEC).unwrap();
        assert_eq!(toolset.len(), 1);
        // The first `servers` entry is used when no override is given.
        assert_eq!(toolset.tools[0].base_url, "https://api.example.com");
    }

    #[test]
    fn builder_base_url_override() {
        let toolset = OpenApiToolset::builder(MINIMAL_SPEC)
            .base_url("https://override.com")
            .build()
            .unwrap();
        assert_eq!(toolset.len(), 1);
        // The override wins over the `servers` entry in the spec.
        assert_eq!(toolset.tools[0].base_url, "https://override.com");
    }

    #[test]
    fn base_url_trailing_slash_is_trimmed() {
        let toolset = OpenApiToolset::builder(MINIMAL_SPEC)
            .base_url("https://override.com/")
            .build()
            .unwrap();
        assert_eq!(toolset.tools[0].base_url, "https://override.com");
    }

    #[test]
    fn builder_bearer_token() {
        let toolset = OpenApiToolset::builder(MINIMAL_SPEC)
            .bearer_token("test-token-123")
            .build()
            .unwrap();
        assert_eq!(toolset.len(), 1);
    }

    #[test]
    fn builder_custom_client() {
        let client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(30))
            .build()
            .unwrap();
        let toolset = OpenApiToolset::builder(MINIMAL_SPEC)
            .client(client)
            .build()
            .unwrap();
        assert_eq!(toolset.len(), 1);
    }

    #[test]
    fn builder_all_options() {
        let toolset = OpenApiToolset::builder(MINIMAL_SPEC)
            .base_url("https://custom.api.com")
            .bearer_token("sk-123")
            .build()
            .unwrap();
        assert_eq!(toolset.len(), 1);
        assert_eq!(toolset.tools[0].base_url, "https://custom.api.com");
    }

    #[test]
    fn base_url_defaults_to_localhost() {
        let spec = r#"
openapi: "3.0.0"
info:
  title: Test
  version: "1.0"
paths:
  /ping:
    get:
      operationId: ping
      summary: Ping
      responses:
        "200":
          description: OK
"#;
        let toolset = OpenApiToolset::from_spec_str(spec).unwrap();
        assert_eq!(toolset.len(), 1);
        // No `servers` entry and no override: the localhost default applies.
        assert_eq!(toolset.tools[0].base_url, "http://localhost");
    }

    #[test]
    fn empty_spec_produces_no_tools() {
        let spec = r#"
openapi: "3.0.0"
info:
  title: Test
  version: "1.0"
paths: {}
"#;
        let toolset = OpenApiToolset::from_spec_str(spec).unwrap();
        assert!(toolset.is_empty());
    }

    #[test]
    fn invalid_yaml_returns_error() {
        let result = OpenApiToolset::from_spec_str("not: [valid: yaml: {{");
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn tool_definition_has_correct_fields() {
        let toolset = OpenApiToolset::from_spec_str(MINIMAL_SPEC).unwrap();
        let tools = toolset.into_tools();
        let def = tools[0].definition("".into()).await;

        assert_eq!(def.name, "getUser");
        assert_eq!(def.description, "Get a user by id");
    }

    #[tokio::test]
    async fn tool_definition_path_param_schema() {
        let toolset = OpenApiToolset::from_spec_str(MINIMAL_SPEC).unwrap();
        let tools = toolset.into_tools();
        let def = tools[0].definition("".into()).await;

        let props = def.parameters["properties"].as_object().unwrap();
        assert!(props.contains_key("id"));

        let required = def.parameters["required"].as_array().unwrap();
        assert!(required.contains(&Value::String("id".into())));
    }

    #[tokio::test]
    async fn tool_definition_query_param_not_required() {
        let toolset = OpenApiToolset::from_spec_str(MULTI_METHOD_SPEC).unwrap();
        let tools = toolset.into_tools();
        let list_tool = tools.iter().find(|t| t.name() == "listUsers").unwrap();
        let def = list_tool.definition("".into()).await;

        let props = def.parameters["properties"].as_object().unwrap();
        assert!(props.contains_key("limit"));

        let required = def.parameters["required"].as_array().unwrap();
        assert!(!required.contains(&Value::String("limit".into())));
    }

    #[tokio::test]
    async fn tool_definition_request_body_schema() {
        let toolset = OpenApiToolset::from_spec_str(MULTI_METHOD_SPEC).unwrap();
        let tools = toolset.into_tools();
        let create_tool = tools.iter().find(|t| t.name() == "createUser").unwrap();
        let def = create_tool.definition("".into()).await;

        let props = def.parameters["properties"].as_object().unwrap();
        assert!(props.contains_key("body"));

        let required = def.parameters["required"].as_array().unwrap();
        assert!(required.contains(&Value::String("body".into())));
    }

    #[tokio::test]
    async fn ref_parameters_are_resolved() {
        let toolset = OpenApiToolset::from_spec_str(REF_SPEC).unwrap();
        let tools = toolset.into_tools();
        assert_eq!(tools.len(), 1);

        let def = tools[0].definition("".into()).await;
        let props = def.parameters["properties"].as_object().unwrap();
        assert!(props.contains_key("id"));
    }

    #[tokio::test]
    async fn tool_definition_header_param() {
        let spec = r#"
openapi: "3.0.0"
info:
  title: Test
  version: "1.0"
paths:
  /data:
    get:
      operationId: getData
      summary: Get data
      parameters:
        - name: X-Request-Id
          in: header
          required: false
          schema:
            type: string
          description: Correlation ID
      responses:
        "200":
          description: OK
"#;
        let toolset = OpenApiToolset::from_spec_str(spec).unwrap();
        let tools = toolset.into_tools();
        let def = tools[0].definition("".into()).await;

        let props = def.parameters["properties"].as_object().unwrap();
        assert!(props.contains_key("X-Request-Id"));
    }

    #[tokio::test]
    async fn tool_call_with_invalid_json_returns_error() {
        let toolset = OpenApiToolset::from_spec_str(MINIMAL_SPEC).unwrap();
        let tools = toolset.into_tools();
        let result = tools[0].call("not json".into()).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn hidden_context_excluded_from_schema() {
        let toolset = OpenApiToolset::builder(MINIMAL_SPEC)
            .hidden_context("id", "123")
            .build()
            .unwrap();
        let tools = toolset.into_tools();
        let def = tools[0].definition("".into()).await;

        let props = def.parameters["properties"].as_object().unwrap();
        assert!(
            !props.contains_key("id"),
            "hidden param should not appear in schema"
        );

        let required = def.parameters["required"].as_array().unwrap();
        assert!(!required.contains(&Value::String("id".into())));
    }

    #[tokio::test]
    async fn tools_with_context_excludes_from_schema() {
        let toolset = OpenApiToolset::from_spec_str(MINIMAL_SPEC).unwrap();

        // Without context, `id` is a visible parameter.
        let tools = toolset.tools_with_context(&HashMap::new());
        let def = tools[0].definition("".into()).await;
        let props = def.parameters["properties"].as_object().unwrap();
        assert!(props.contains_key("id"));

        // Supplying `id` through the context hides it from the schema.
        let ctx = HashMap::from([("id".to_string(), "42".to_string())]);
        let tools = toolset.tools_with_context(&ctx);
        let def = tools[0].definition("".into()).await;
        let props = def.parameters["properties"].as_object().unwrap();
        assert!(!props.contains_key("id"));
    }

    #[test]
    fn toolset_reusable_across_contexts() {
        let toolset = OpenApiToolset::from_spec_str(MULTI_METHOD_SPEC).unwrap();

        let ctx1 = HashMap::from([("id".to_string(), "1".to_string())]);
        let ctx2 = HashMap::from([("id".to_string(), "2".to_string())]);

        let tools1 = toolset.tools_with_context(&ctx1);
        let tools2 = toolset.tools_with_context(&ctx2);

        assert_eq!(tools1.len(), 4);
        assert_eq!(tools2.len(), 4);
    }

    #[test]
    fn context_preamble_generation() {
        let ctx = HashMap::from([("user_id".to_string(), "123".to_string())]);
        let preamble = OpenApiToolset::context_preamble(&ctx);
        assert!(preamble.contains("user_id = 123"));
        assert!(preamble.contains("Use these values"));
    }

    #[test]
    fn context_preamble_empty() {
        let preamble = OpenApiToolset::context_preamble(&HashMap::new());
        assert!(preamble.is_empty());
    }
}