1mod extract;
24mod resolve;
25mod tool;
26
27use std::collections::HashMap;
28use std::path::Path;
29
30use openapiv3::{OpenAPI, ReferenceOr};
31use rig::tool::ToolDyn;
32
33use crate::extract::{extract_body_schema, extract_param_info};
34use crate::resolve::Resolver;
35use crate::tool::{HttpMethod, OpenApiTool};
36
37pub struct OpenApiToolset {
46 tools: Vec<OpenApiTool>,
47}
48
49pub struct OpenApiToolsetBuilder {
51 spec_str: String,
52 base_url: Option<String>,
53 client: Option<reqwest::Client>,
54 hidden_context: HashMap<String, String>,
55 default_headers: reqwest::header::HeaderMap,
56 static_query_params: Vec<(String, String)>,
57 basic_auth: Option<(String, String)>,
58}
59
60impl OpenApiToolsetBuilder {
61 pub fn base_url(mut self, url: impl Into<String>) -> Self {
63 self.base_url = Some(url.into());
64 self
65 }
66
67 pub fn client(mut self, client: reqwest::Client) -> Self {
69 self.client = Some(client);
70 self
71 }
72
73 pub fn hidden_context(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
76 self.hidden_context.insert(key.into(), value.into());
77 self
78 }
79
80 pub fn bearer_token(mut self, token: &str) -> Self {
82 use reqwest::header;
83 let mut auth_value =
84 header::HeaderValue::from_str(&format!("Bearer {token}")).expect("invalid token");
85 auth_value.set_sensitive(true);
86 self.default_headers
87 .insert(header::AUTHORIZATION, auth_value);
88 self
89 }
90
91 pub fn api_key_header(mut self, header_name: &str, key: &str) -> Self {
93 use reqwest::header::HeaderValue;
94 let name = reqwest::header::HeaderName::from_bytes(header_name.as_bytes())
95 .expect("invalid header name");
96 let mut value = HeaderValue::from_str(key).expect("invalid header value");
97 value.set_sensitive(true);
98 self.default_headers.insert(name, value);
99 self
100 }
101
102 pub fn api_key_query(mut self, param_name: &str, key: &str) -> Self {
104 self.static_query_params
105 .push((param_name.to_string(), key.to_string()));
106 self
107 }
108
109 pub fn basic_auth(mut self, username: &str, password: &str) -> Self {
111 self.basic_auth = Some((username.to_string(), password.to_string()));
112 self
113 }
114
115 pub fn build(self) -> anyhow::Result<OpenApiToolset> {
117 let client = if let Some(c) = self.client {
118 c
119 } else {
120 reqwest::Client::builder()
121 .default_headers(self.default_headers)
122 .build()?
123 };
124 OpenApiToolset::build_inner(
125 &self.spec_str,
126 self.base_url.as_deref(),
127 client,
128 self.hidden_context,
129 self.static_query_params,
130 self.basic_auth,
131 )
132 }
133}
134
135impl OpenApiToolset {
136 pub fn from_file(path: impl AsRef<Path>) -> anyhow::Result<Self> {
138 let content = std::fs::read_to_string(path)?;
139 Self::from_spec_str(&content)
140 }
141
142 pub fn from_spec_str(spec_str: &str) -> anyhow::Result<Self> {
144 Self::build_inner(
145 spec_str,
146 None,
147 reqwest::Client::default(),
148 HashMap::new(),
149 Vec::new(),
150 None,
151 )
152 }
153
154 pub fn builder(spec_str: &str) -> OpenApiToolsetBuilder {
156 OpenApiToolsetBuilder {
157 spec_str: spec_str.to_string(),
158 base_url: None,
159 client: None,
160 hidden_context: HashMap::new(),
161 default_headers: reqwest::header::HeaderMap::new(),
162 static_query_params: Vec::new(),
163 basic_auth: None,
164 }
165 }
166
167 pub fn builder_from_file(path: impl AsRef<Path>) -> anyhow::Result<OpenApiToolsetBuilder> {
169 let content = std::fs::read_to_string(path)?;
170 Ok(OpenApiToolsetBuilder {
171 spec_str: content,
172 base_url: None,
173 client: None,
174 hidden_context: HashMap::new(),
175 default_headers: reqwest::header::HeaderMap::new(),
176 static_query_params: Vec::new(),
177 basic_auth: None,
178 })
179 }
180
181 fn build_inner(
182 spec_str: &str,
183 base_url_override: Option<&str>,
184 client: reqwest::Client,
185 hidden_context: HashMap<String, String>,
186 static_query_params: Vec<(String, String)>,
187 basic_auth: Option<(String, String)>,
188 ) -> anyhow::Result<Self> {
189 let spec: OpenAPI = serde_yaml::from_str(spec_str)?;
190 let resolver = Resolver::new(&spec);
191
192 let base_url = base_url_override
193 .map(|s| s.to_string())
194 .or_else(|| spec.servers.first().map(|s| s.url.clone()))
195 .unwrap_or_else(|| "http://localhost".into());
196 let base_url = base_url.trim_end_matches('/').to_string();
197
198 let mut tools: Vec<OpenApiTool> = Vec::new();
199
200 for (path_template, path_item_ref) in spec.paths.iter() {
201 let ReferenceOr::Item(path_item) = path_item_ref else {
202 continue;
203 };
204
205 let methods = [
206 (HttpMethod::Get, &path_item.get),
207 (HttpMethod::Post, &path_item.post),
208 (HttpMethod::Put, &path_item.put),
209 (HttpMethod::Patch, &path_item.patch),
210 (HttpMethod::Delete, &path_item.delete),
211 ];
212
213 for (method, op) in methods {
214 let Some(op) = op else { continue };
215
216 let method_lower: String = method.as_str().to_lowercase();
217 let operation_id = op.operation_id.clone().unwrap_or_else(|| {
218 let path_slug = path_template.replace('/', "_");
219 let path_slug = path_slug.trim_start_matches('_');
220 format!("{}_{}", method_lower, path_slug)
221 });
222
223 let description = op
224 .summary
225 .clone()
226 .or_else(|| op.description.clone())
227 .unwrap_or_else(|| format!("{} {}", method.as_str(), path_template));
228
229 let parameters = op
230 .parameters
231 .iter()
232 .filter_map(|p| {
233 let param = resolver.resolve_parameter(p)?;
234 extract_param_info(param, &resolver)
235 })
236 .collect();
237
238 let (request_body_schema, request_body_required) = op
239 .request_body
240 .as_ref()
241 .and_then(|rb| resolver.resolve_request_body(rb))
242 .map(|body| extract_body_schema(body, &resolver))
243 .unwrap_or((None, false));
244
245 tools.push(OpenApiTool {
246 client: client.clone(),
247 base_url: base_url.clone(),
248 method,
249 path_template: path_template.clone(),
250 operation_id,
251 description,
252 parameters,
253 request_body_schema,
254 request_body_required,
255 hidden_params: hidden_context.clone(),
256 static_query_params: static_query_params.clone(),
257 basic_auth: basic_auth.clone(),
258 });
259 }
260 }
261
262 Ok(Self { tools })
263 }
264
265 pub fn len(&self) -> usize {
267 self.tools.len()
268 }
269
270 pub fn is_empty(&self) -> bool {
272 self.tools.is_empty()
273 }
274
275 pub fn into_tools(self) -> Vec<Box<dyn ToolDyn>> {
277 self.tools
278 .into_iter()
279 .map(|t| Box::new(t) as Box<dyn ToolDyn>)
280 .collect()
281 }
282
283 pub fn tools_with_context(&self, context: &HashMap<String, String>) -> Vec<Box<dyn ToolDyn>> {
290 self.tools
291 .iter()
292 .map(|t| {
293 let mut tool = t.clone();
294 tool.hidden_params.extend(context.clone());
295 Box::new(tool) as Box<dyn ToolDyn>
296 })
297 .collect()
298 }
299
300 pub fn context_preamble(context: &HashMap<String, String>) -> String {
304 if context.is_empty() {
305 return String::new();
306 }
307 let entries: Vec<String> = context
308 .iter()
309 .map(|(k, v)| format!("- {k} = {v}"))
310 .collect();
311 format!(
312 "The following context is available. Use these values when calling tools:\n{}",
313 entries.join("\n")
314 )
315 }
316}
317
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::Value;

    /// One GET operation with a single required path parameter.
    const MINIMAL_SPEC: &str = r#"
openapi: "3.0.0"
info:
  title: Test
  version: "1.0"
servers:
  - url: https://api.example.com
paths:
  /users/{id}:
    get:
      operationId: getUser
      summary: Get a user by id
      parameters:
        - name: id
          in: path
          required: true
          schema:
            type: string
          description: The user id
      responses:
        "200":
          description: OK
"#;

    /// Four operations across two paths, including a request body and a query parameter.
    const MULTI_METHOD_SPEC: &str = r#"
openapi: "3.0.0"
info:
  title: Test
  version: "1.0"
servers:
  - url: https://api.example.com
paths:
  /users:
    get:
      operationId: listUsers
      summary: List all users
      parameters:
        - name: limit
          in: query
          required: false
          schema:
            type: integer
          description: Max results
      responses:
        "200":
          description: OK
    post:
      operationId: createUser
      summary: Create a user
      requestBody:
        required: true
        content:
          application/json:
            schema:
              type: object
              properties:
                name:
                  type: string
                email:
                  type: string
              required:
                - name
      responses:
        "201":
          description: Created
  /users/{id}:
    get:
      operationId: getUser
      summary: Get a user
      parameters:
        - name: id
          in: path
          required: true
          schema:
            type: string
      responses:
        "200":
          description: OK
    delete:
      operationId: deleteUser
      summary: Delete a user
      parameters:
        - name: id
          in: path
          required: true
          schema:
            type: string
      responses:
        "204":
          description: Deleted
"#;

    /// A parameter supplied through `$ref` into `components/parameters`.
    const REF_SPEC: &str = r#"
openapi: "3.0.0"
info:
  title: Test
  version: "1.0"
servers:
  - url: https://api.example.com
paths:
  /items/{id}:
    get:
      operationId: getItem
      summary: Get an item
      parameters:
        - $ref: '#/components/parameters/ItemId'
      responses:
        "200":
          description: OK
components:
  parameters:
    ItemId:
      name: id
      in: path
      required: true
      schema:
        type: string
      description: The item id
"#;

    #[test]
    fn parse_minimal_spec() {
        let toolset = OpenApiToolset::from_spec_str(MINIMAL_SPEC).unwrap();
        assert_eq!(toolset.len(), 1);
    }

    #[test]
    fn parse_multi_method_spec() {
        let toolset = OpenApiToolset::from_spec_str(MULTI_METHOD_SPEC).unwrap();
        assert_eq!(toolset.len(), 4);
    }

    #[test]
    fn tool_names_match_operation_ids() {
        let toolset = OpenApiToolset::from_spec_str(MULTI_METHOD_SPEC).unwrap();
        let tools = toolset.into_tools();
        let names: Vec<String> = tools.iter().map(|t| t.name()).collect();
        assert!(names.contains(&"listUsers".to_string()));
        assert!(names.contains(&"createUser".to_string()));
        assert!(names.contains(&"getUser".to_string()));
        assert!(names.contains(&"deleteUser".to_string()));
    }

    #[test]
    fn fallback_operation_id_when_missing() {
        let spec = r#"
openapi: "3.0.0"
info:
  title: Test
  version: "1.0"
paths:
  /health:
    get:
      summary: Health check
      responses:
        "200":
          description: OK
"#;
        let toolset = OpenApiToolset::from_spec_str(spec).unwrap();
        let tools = toolset.into_tools();
        assert_eq!(tools[0].name(), "get_health");
    }

    #[test]
    fn base_url_from_spec() {
        let toolset = OpenApiToolset::from_spec_str(MINIMAL_SPEC).unwrap();
        assert_eq!(toolset.len(), 1);
        assert_eq!(toolset.tools[0].base_url, "https://api.example.com");
    }

    #[test]
    fn builder_base_url_override() {
        let toolset = OpenApiToolset::builder(MINIMAL_SPEC)
            .base_url("https://override.com")
            .build()
            .unwrap();
        assert_eq!(toolset.len(), 1);
        assert_eq!(toolset.tools[0].base_url, "https://override.com");
    }

    #[test]
    fn builder_base_url_trailing_slash_trimmed() {
        let toolset = OpenApiToolset::builder(MINIMAL_SPEC)
            .base_url("https://override.com/")
            .build()
            .unwrap();
        assert_eq!(toolset.tools[0].base_url, "https://override.com");
    }

    #[test]
    fn builder_bearer_token() {
        let toolset = OpenApiToolset::builder(MINIMAL_SPEC)
            .bearer_token("test-token-123")
            .build()
            .unwrap();
        assert_eq!(toolset.len(), 1);
    }

    #[test]
    fn builder_custom_client() {
        let client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(30))
            .build()
            .unwrap();
        let toolset = OpenApiToolset::builder(MINIMAL_SPEC)
            .client(client)
            .build()
            .unwrap();
        assert_eq!(toolset.len(), 1);
    }

    #[test]
    fn builder_all_options() {
        let toolset = OpenApiToolset::builder(MINIMAL_SPEC)
            .base_url("https://custom.api.com")
            .bearer_token("sk-123")
            .build()
            .unwrap();
        assert_eq!(toolset.len(), 1);
        assert_eq!(toolset.tools[0].base_url, "https://custom.api.com");
    }

    #[test]
    fn base_url_defaults_to_localhost() {
        let spec = r#"
openapi: "3.0.0"
info:
  title: Test
  version: "1.0"
paths:
  /ping:
    get:
      operationId: ping
      summary: Ping
      responses:
        "200":
          description: OK
"#;
        let toolset = OpenApiToolset::from_spec_str(spec).unwrap();
        assert_eq!(toolset.len(), 1);
        assert_eq!(toolset.tools[0].base_url, "http://localhost");
    }

    #[test]
    fn empty_spec_produces_no_tools() {
        let spec = r#"
openapi: "3.0.0"
info:
  title: Test
  version: "1.0"
paths: {}
"#;
        let toolset = OpenApiToolset::from_spec_str(spec).unwrap();
        assert!(toolset.is_empty());
    }

    #[test]
    fn invalid_yaml_returns_error() {
        let result = OpenApiToolset::from_spec_str("not: [valid: yaml: {{");
        assert!(result.is_err());
    }

    #[test]
    fn missing_spec_file_returns_error() {
        let missing = "/definitely/not/a/real/dir/openapi-spec.yaml";
        assert!(OpenApiToolset::from_file(missing).is_err());
        assert!(OpenApiToolset::builder_from_file(missing).is_err());
    }

    #[tokio::test]
    async fn tool_definition_has_correct_fields() {
        let toolset = OpenApiToolset::from_spec_str(MINIMAL_SPEC).unwrap();
        let tools = toolset.into_tools();
        let def = tools[0].definition("".into()).await;

        assert_eq!(def.name, "getUser");
        assert_eq!(def.description, "Get a user by id");
    }

    #[tokio::test]
    async fn description_falls_back_to_operation_description() {
        let spec = r#"
openapi: "3.0.0"
info:
  title: Test
  version: "1.0"
paths:
  /status:
    get:
      operationId: getStatus
      description: Returns service status
      responses:
        "200":
          description: OK
"#;
        let toolset = OpenApiToolset::from_spec_str(spec).unwrap();
        let tools = toolset.into_tools();
        let def = tools[0].definition("".into()).await;
        assert_eq!(def.description, "Returns service status");
    }

    #[tokio::test]
    async fn tool_definition_path_param_schema() {
        let toolset = OpenApiToolset::from_spec_str(MINIMAL_SPEC).unwrap();
        let tools = toolset.into_tools();
        let def = tools[0].definition("".into()).await;

        let props = def.parameters["properties"].as_object().unwrap();
        assert!(props.contains_key("id"));

        let required = def.parameters["required"].as_array().unwrap();
        assert!(required.contains(&Value::String("id".into())));
    }

    #[tokio::test]
    async fn tool_definition_query_param_not_required() {
        let toolset = OpenApiToolset::from_spec_str(MULTI_METHOD_SPEC).unwrap();
        let tools = toolset.into_tools();
        let list_tool = tools.iter().find(|t| t.name() == "listUsers").unwrap();
        let def = list_tool.definition("".into()).await;

        let props = def.parameters["properties"].as_object().unwrap();
        assert!(props.contains_key("limit"));

        let required = def.parameters["required"].as_array().unwrap();
        assert!(!required.contains(&Value::String("limit".into())));
    }

    #[tokio::test]
    async fn tool_definition_request_body_schema() {
        let toolset = OpenApiToolset::from_spec_str(MULTI_METHOD_SPEC).unwrap();
        let tools = toolset.into_tools();
        let create_tool = tools.iter().find(|t| t.name() == "createUser").unwrap();
        let def = create_tool.definition("".into()).await;

        let props = def.parameters["properties"].as_object().unwrap();
        assert!(props.contains_key("body"));

        let required = def.parameters["required"].as_array().unwrap();
        assert!(required.contains(&Value::String("body".into())));
    }

    #[tokio::test]
    async fn ref_parameters_are_resolved() {
        let toolset = OpenApiToolset::from_spec_str(REF_SPEC).unwrap();
        let tools = toolset.into_tools();
        assert_eq!(tools.len(), 1);

        let def = tools[0].definition("".into()).await;
        let props = def.parameters["properties"].as_object().unwrap();
        assert!(props.contains_key("id"));
    }

    #[tokio::test]
    async fn tool_definition_header_param() {
        let spec = r#"
openapi: "3.0.0"
info:
  title: Test
  version: "1.0"
paths:
  /data:
    get:
      operationId: getData
      summary: Get data
      parameters:
        - name: X-Request-Id
          in: header
          required: false
          schema:
            type: string
          description: Correlation ID
      responses:
        "200":
          description: OK
"#;
        let toolset = OpenApiToolset::from_spec_str(spec).unwrap();
        let tools = toolset.into_tools();
        let def = tools[0].definition("".into()).await;

        let props = def.parameters["properties"].as_object().unwrap();
        assert!(props.contains_key("X-Request-Id"));
    }

    #[tokio::test]
    async fn tool_call_with_invalid_json_returns_error() {
        let toolset = OpenApiToolset::from_spec_str(MINIMAL_SPEC).unwrap();
        let tools = toolset.into_tools();
        let result = tools[0].call("not json".into()).await;
        assert!(result.is_err());
    }

    #[test]
    fn hidden_context_stored_on_tools() {
        let toolset = OpenApiToolset::builder(MINIMAL_SPEC)
            .hidden_context("id", "123")
            .build()
            .unwrap();
        let tool = &toolset.tools[0];
        assert_eq!(tool.hidden_params.get("id").map(String::as_str), Some("123"));
    }

    #[tokio::test]
    async fn hidden_context_excluded_from_schema() {
        let toolset = OpenApiToolset::builder(MINIMAL_SPEC)
            .hidden_context("id", "123")
            .build()
            .unwrap();
        let tools = toolset.into_tools();
        let def = tools[0].definition("".into()).await;

        let props = def.parameters["properties"].as_object().unwrap();
        assert!(
            !props.contains_key("id"),
            "hidden param should not appear in schema"
        );

        let required = def.parameters["required"].as_array().unwrap();
        assert!(!required.contains(&Value::String("id".into())));
    }

    #[tokio::test]
    async fn tools_with_context_excludes_from_schema() {
        let toolset = OpenApiToolset::from_spec_str(MINIMAL_SPEC).unwrap();

        // Without context the path parameter is visible to the model.
        let tools = toolset.tools_with_context(&HashMap::new());
        let def = tools[0].definition("".into()).await;
        let props = def.parameters["properties"].as_object().unwrap();
        assert!(props.contains_key("id"));

        // Supplying it through context hides it.
        let ctx = HashMap::from([("id".to_string(), "42".to_string())]);
        let tools = toolset.tools_with_context(&ctx);
        let def = tools[0].definition("".into()).await;
        let props = def.parameters["properties"].as_object().unwrap();
        assert!(!props.contains_key("id"));
    }

    #[test]
    fn toolset_reusable_across_contexts() {
        let toolset = OpenApiToolset::from_spec_str(MULTI_METHOD_SPEC).unwrap();

        let ctx1 = HashMap::from([("id".to_string(), "1".to_string())]);
        let ctx2 = HashMap::from([("id".to_string(), "2".to_string())]);

        let tools1 = toolset.tools_with_context(&ctx1);
        let tools2 = toolset.tools_with_context(&ctx2);

        assert_eq!(tools1.len(), 4);
        assert_eq!(tools2.len(), 4);
    }

    #[test]
    fn context_preamble_generation() {
        let ctx = HashMap::from([("user_id".to_string(), "123".to_string())]);
        let preamble = OpenApiToolset::context_preamble(&ctx);
        assert!(preamble.contains("user_id = 123"));
        assert!(preamble.contains("Use these values"));
    }

    #[test]
    fn context_preamble_empty() {
        let preamble = OpenApiToolset::context_preamble(&HashMap::new());
        assert!(preamble.is_empty());
    }

    #[test]
    fn builder_api_key_header() {
        let toolset = OpenApiToolset::builder(MINIMAL_SPEC)
            .api_key_header("X-API-Key", "abc123")
            .build()
            .unwrap();
        assert_eq!(toolset.len(), 1);
    }

    #[test]
    fn builder_api_key_query() {
        let toolset = OpenApiToolset::builder(MINIMAL_SPEC)
            .api_key_query("api_key", "abc123")
            .build()
            .unwrap();
        assert_eq!(toolset.len(), 1);
    }

    #[test]
    fn builder_basic_auth() {
        let toolset = OpenApiToolset::builder(MINIMAL_SPEC)
            .basic_auth("user", "pass")
            .build()
            .unwrap();
        assert_eq!(toolset.len(), 1);
    }

    #[test]
    fn builder_multiple_auth() {
        let toolset = OpenApiToolset::builder(MINIMAL_SPEC)
            .bearer_token("sk-123")
            .api_key_header("X-Tenant-Id", "tenant-abc")
            .build()
            .unwrap();
        assert_eq!(toolset.len(), 1);
    }

    #[test]
    fn api_key_query_params_stored_on_tools() {
        let toolset = OpenApiToolset::builder(MINIMAL_SPEC)
            .api_key_query("api_key", "secret123")
            .build()
            .unwrap();
        let tool = &toolset.tools[0];
        assert!(tool
            .static_query_params
            .contains(&("api_key".to_string(), "secret123".to_string())));
    }

    #[test]
    fn multiple_api_key_queries_stack() {
        let toolset = OpenApiToolset::builder(MINIMAL_SPEC)
            .api_key_query("api_key", "key1")
            .api_key_query("version", "v2")
            .build()
            .unwrap();
        let tool = &toolset.tools[0];
        assert_eq!(tool.static_query_params.len(), 2);
        assert!(tool
            .static_query_params
            .contains(&("api_key".to_string(), "key1".to_string())));
        assert!(tool
            .static_query_params
            .contains(&("version".to_string(), "v2".to_string())));
    }

    #[test]
    fn basic_auth_credentials_stored_on_tools() {
        let toolset = OpenApiToolset::builder(MINIMAL_SPEC)
            .basic_auth("alice", "s3cr3t")
            .build()
            .unwrap();
        let tool = &toolset.tools[0];
        assert_eq!(
            tool.basic_auth,
            Some(("alice".to_string(), "s3cr3t".to_string()))
        );
    }

    #[test]
    fn basic_auth_not_set_by_default() {
        let toolset = OpenApiToolset::from_spec_str(MINIMAL_SPEC).unwrap();
        let tool = &toolset.tools[0];
        assert!(tool.basic_auth.is_none());
    }

    #[test]
    fn api_key_query_not_set_by_default() {
        let toolset = OpenApiToolset::from_spec_str(MINIMAL_SPEC).unwrap();
        let tool = &toolset.tools[0];
        assert!(tool.static_query_params.is_empty());
    }
}