// rig_openapi_tools/lib.rs

//! Turn any OpenAPI spec into LLM-callable tools for [rig](https://docs.rs/rig-core).
//!
//! Parse an OpenAPI 3.0 YAML or JSON spec into a set of tools that can be
//! registered directly with a rig agent. Each GET, POST, PUT, PATCH, or DELETE
//! operation in the spec becomes a tool the LLM can call.
//!
//! # Quick start
//!
//! ```no_run
//! use rig_openapi_tools::OpenApiToolset;
//!
//! let spec = std::fs::read_to_string("openapi.yaml").unwrap();
//! let toolset = OpenApiToolset::builder(&spec)
//!     .base_url("https://api.example.com")
//!     .bearer_token("sk-...")
//!     .build()
//!     .unwrap();
//!
//! // Register with a rig agent
//! // agent_builder.tools(toolset.into_tools())
//! ```

23mod extract;
24mod resolve;
25mod tool;
26
27use std::collections::HashMap;
28use std::path::Path;
29
30use openapiv3::{OpenAPI, ReferenceOr};
31use rig::tool::ToolDyn;
32
33use crate::extract::{extract_body_schema, extract_param_info};
34use crate::resolve::Resolver;
35use crate::tool::{HttpMethod, OpenApiTool};
36
// ---------------------------------------------------------------------------
// Public API
// ---------------------------------------------------------------------------

41/// A set of tools generated from an OpenAPI specification.
42///
43/// Each operation in the spec becomes a tool that can be registered with a rig agent.
44/// The toolset is designed to be parsed once and reused across requests.
45pub struct OpenApiToolset {
46    tools: Vec<OpenApiTool>,
47}
48
49/// Builder for configuring an [`OpenApiToolset`].
50pub struct OpenApiToolsetBuilder {
51    spec_str: String,
52    base_url: Option<String>,
53    client: Option<reqwest::Client>,
54    hidden_context: HashMap<String, String>,
55}
56
57impl OpenApiToolsetBuilder {
58    /// Override the base URL from the spec.
59    pub fn base_url(mut self, url: impl Into<String>) -> Self {
60        self.base_url = Some(url.into());
61        self
62    }
63
64    /// Provide a pre-configured reqwest client (e.g. with default auth headers or timeouts).
65    pub fn client(mut self, client: reqwest::Client) -> Self {
66        self.client = Some(client);
67        self
68    }
69
70    /// Add a hidden context parameter that will be auto-injected into tool calls.
71    /// The LLM will not see this parameter in the tool schema.
72    pub fn hidden_context(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
73        self.hidden_context.insert(key.into(), value.into());
74        self
75    }
76
77    /// Convenience: create a client with a bearer token `Authorization` header.
78    pub fn bearer_token(self, token: &str) -> Self {
79        use reqwest::header;
80        let mut headers = header::HeaderMap::new();
81        let mut auth_value =
82            header::HeaderValue::from_str(&format!("Bearer {token}")).expect("invalid token");
83        auth_value.set_sensitive(true);
84        headers.insert(header::AUTHORIZATION, auth_value);
85
86        let client = reqwest::Client::builder()
87            .default_headers(headers)
88            .build()
89            .expect("failed to build reqwest client");
90        self.client(client)
91    }
92
93    /// Build the toolset, parsing the spec and creating tools.
94    pub fn build(self) -> anyhow::Result<OpenApiToolset> {
95        OpenApiToolset::build_inner(
96            &self.spec_str,
97            self.base_url.as_deref(),
98            self.client,
99            self.hidden_context,
100        )
101    }
102}
103
104impl OpenApiToolset {
105    /// Parse an OpenAPI spec from a YAML or JSON file.
106    pub fn from_file(path: impl AsRef<Path>) -> anyhow::Result<Self> {
107        let content = std::fs::read_to_string(path)?;
108        Self::from_spec_str(&content)
109    }
110
111    /// Parse an OpenAPI spec from a YAML or JSON string.
112    pub fn from_spec_str(spec_str: &str) -> anyhow::Result<Self> {
113        Self::build_inner(spec_str, None, None, HashMap::new())
114    }
115
116    /// Start building a toolset from a YAML or JSON string with configuration options.
117    pub fn builder(spec_str: &str) -> OpenApiToolsetBuilder {
118        OpenApiToolsetBuilder {
119            spec_str: spec_str.to_string(),
120            base_url: None,
121            client: None,
122            hidden_context: HashMap::new(),
123        }
124    }
125
126    /// Start building a toolset from a file with configuration options.
127    pub fn builder_from_file(path: impl AsRef<Path>) -> anyhow::Result<OpenApiToolsetBuilder> {
128        let content = std::fs::read_to_string(path)?;
129        Ok(OpenApiToolsetBuilder {
130            spec_str: content,
131            base_url: None,
132            client: None,
133            hidden_context: HashMap::new(),
134        })
135    }
136
137    fn build_inner(
138        spec_str: &str,
139        base_url_override: Option<&str>,
140        client: Option<reqwest::Client>,
141        hidden_context: HashMap<String, String>,
142    ) -> anyhow::Result<Self> {
143        let spec: OpenAPI = serde_yaml::from_str(spec_str)?;
144        let resolver = Resolver::new(&spec);
145
146        let base_url = base_url_override
147            .map(|s| s.to_string())
148            .or_else(|| spec.servers.first().map(|s| s.url.clone()))
149            .unwrap_or_else(|| "http://localhost".into());
150        let base_url = base_url.trim_end_matches('/').to_string();
151
152        let client = client.unwrap_or_default();
153        let mut tools: Vec<OpenApiTool> = Vec::new();
154
155        for (path_template, path_item_ref) in &spec.paths {
156            let ReferenceOr::Item(path_item) = path_item_ref else {
157                continue;
158            };
159
160            let methods = [
161                (HttpMethod::Get, &path_item.get),
162                (HttpMethod::Post, &path_item.post),
163                (HttpMethod::Put, &path_item.put),
164                (HttpMethod::Patch, &path_item.patch),
165                (HttpMethod::Delete, &path_item.delete),
166            ];
167
168            for (method, op) in methods {
169                let Some(op) = op else { continue };
170
171                let method_lower = method.as_str().to_lowercase();
172                let operation_id = op.operation_id.clone().unwrap_or_else(|| {
173                    let path_slug = path_template.replace('/', "_");
174                    let path_slug = path_slug.trim_start_matches('_');
175                    format!("{}_{}", method_lower, path_slug)
176                });
177
178                let description = op
179                    .summary
180                    .clone()
181                    .or_else(|| op.description.clone())
182                    .unwrap_or_else(|| format!("{} {}", method.as_str(), path_template));
183
184                let parameters = op
185                    .parameters
186                    .iter()
187                    .filter_map(|p| {
188                        let param = resolver.resolve_parameter(p)?;
189                        extract_param_info(param, &resolver)
190                    })
191                    .collect();
192
193                let (request_body_schema, request_body_required) = op
194                    .request_body
195                    .as_ref()
196                    .and_then(|rb| resolver.resolve_request_body(rb))
197                    .map(|body| extract_body_schema(body, &resolver))
198                    .unwrap_or((None, false));
199
200                tools.push(OpenApiTool {
201                    client: client.clone(),
202                    base_url: base_url.clone(),
203                    method,
204                    path_template: path_template.clone(),
205                    operation_id,
206                    description,
207                    parameters,
208                    request_body_schema,
209                    request_body_required,
210                    hidden_params: hidden_context.clone(),
211                });
212            }
213        }
214
215        Ok(Self { tools })
216    }
217
218    /// Return the number of tools parsed from the spec.
219    pub fn len(&self) -> usize {
220        self.tools.len()
221    }
222
223    /// Returns true if no operations were found in the spec.
224    pub fn is_empty(&self) -> bool {
225        self.tools.is_empty()
226    }
227
228    /// Consume the toolset and return tools for use with rig's `AgentBuilder::tools()`.
229    pub fn into_tools(self) -> Vec<Box<dyn ToolDyn>> {
230        self.tools
231            .into_iter()
232            .map(|t| Box::new(t) as Box<dyn ToolDyn>)
233            .collect()
234    }
235
236    /// Clone the tools with per-request context injected as hidden parameters.
237    /// The LLM will not see these parameters in tool schemas, but they will be
238    /// auto-injected into every tool call at execution time.
239    ///
240    /// This is the primary way to add per-request state (user ID, session info, etc.)
241    /// while reusing the parsed toolset across requests.
242    pub fn tools_with_context(&self, context: &HashMap<String, String>) -> Vec<Box<dyn ToolDyn>> {
243        self.tools
244            .iter()
245            .map(|t| {
246                let mut tool = t.clone();
247                tool.hidden_params.extend(context.clone());
248                Box::new(tool) as Box<dyn ToolDyn>
249            })
250            .collect()
251    }
252
253    /// Generate a preamble snippet describing the visible context for the LLM.
254    /// Include this in your agent's `.preamble()` so the LLM knows about
255    /// available context values it can use when calling tools.
256    pub fn context_preamble(context: &HashMap<String, String>) -> String {
257        if context.is_empty() {
258            return String::new();
259        }
260        let entries: Vec<String> = context
261            .iter()
262            .map(|(k, v)| format!("- {k} = {v}"))
263            .collect();
264        format!(
265            "The following context is available. Use these values when calling tools:\n{}",
266            entries.join("\n")
267        )
268    }
269}
270
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::Value;

    const MINIMAL_SPEC: &str = r#"
openapi: "3.0.0"
info:
  title: Test
  version: "1.0"
servers:
  - url: https://api.example.com
paths:
  /users/{id}:
    get:
      operationId: getUser
      summary: Get a user by id
      parameters:
        - name: id
          in: path
          required: true
          schema:
            type: string
          description: The user id
      responses:
        "200":
          description: OK
"#;

    const MULTI_METHOD_SPEC: &str = r#"
openapi: "3.0.0"
info:
  title: Test
  version: "1.0"
servers:
  - url: https://api.example.com
paths:
  /users:
    get:
      operationId: listUsers
      summary: List all users
      parameters:
        - name: limit
          in: query
          required: false
          schema:
            type: integer
          description: Max results
      responses:
        "200":
          description: OK
    post:
      operationId: createUser
      summary: Create a user
      requestBody:
        required: true
        content:
          application/json:
            schema:
              type: object
              properties:
                name:
                  type: string
                email:
                  type: string
              required:
                - name
      responses:
        "201":
          description: Created
  /users/{id}:
    get:
      operationId: getUser
      summary: Get a user
      parameters:
        - name: id
          in: path
          required: true
          schema:
            type: string
      responses:
        "200":
          description: OK
    delete:
      operationId: deleteUser
      summary: Delete a user
      parameters:
        - name: id
          in: path
          required: true
          schema:
            type: string
      responses:
        "204":
          description: Deleted
"#;

    const REF_SPEC: &str = r#"
openapi: "3.0.0"
info:
  title: Test
  version: "1.0"
servers:
  - url: https://api.example.com
paths:
  /items/{id}:
    get:
      operationId: getItem
      summary: Get an item
      parameters:
        - $ref: '#/components/parameters/ItemId'
      responses:
        "200":
          description: OK
components:
  parameters:
    ItemId:
      name: id
      in: path
      required: true
      schema:
        type: string
      description: The item id
"#;

    #[test]
    fn parse_minimal_spec() {
        let toolset = OpenApiToolset::from_spec_str(MINIMAL_SPEC).unwrap();
        assert_eq!(toolset.len(), 1);
    }

    #[test]
    fn parse_multi_method_spec() {
        let toolset = OpenApiToolset::from_spec_str(MULTI_METHOD_SPEC).unwrap();
        assert_eq!(toolset.len(), 4);
    }

    #[test]
    fn tool_names_match_operation_ids() {
        let toolset = OpenApiToolset::from_spec_str(MULTI_METHOD_SPEC).unwrap();
        let tools = toolset.into_tools();
        let names: Vec<String> = tools.iter().map(|t| t.name()).collect();
        assert!(names.contains(&"listUsers".to_string()));
        assert!(names.contains(&"createUser".to_string()));
        assert!(names.contains(&"getUser".to_string()));
        assert!(names.contains(&"deleteUser".to_string()));
    }

    #[test]
    fn fallback_operation_id_when_missing() {
        let spec = r#"
openapi: "3.0.0"
info:
  title: Test
  version: "1.0"
paths:
  /health:
    get:
      summary: Health check
      responses:
        "200":
          description: OK
"#;
        let toolset = OpenApiToolset::from_spec_str(spec).unwrap();
        let tools = toolset.into_tools();
        assert_eq!(tools[0].name(), "get_health");
    }

    #[test]
    fn base_url_from_spec() {
        let toolset = OpenApiToolset::from_spec_str(MINIMAL_SPEC).unwrap();
        assert_eq!(toolset.len(), 1);
        assert_eq!(toolset.tools[0].base_url, "https://api.example.com");
    }

    #[test]
    fn builder_base_url_override() {
        let toolset = OpenApiToolset::builder(MINIMAL_SPEC)
            .base_url("https://override.com")
            .build()
            .unwrap();
        assert_eq!(toolset.len(), 1);
        assert_eq!(toolset.tools[0].base_url, "https://override.com");
    }

    #[test]
    fn builder_base_url_trailing_slash_is_trimmed() {
        let toolset = OpenApiToolset::builder(MINIMAL_SPEC)
            .base_url("https://override.com/v1/")
            .build()
            .unwrap();
        assert_eq!(toolset.tools[0].base_url, "https://override.com/v1");
    }

    #[test]
    fn builder_bearer_token() {
        let toolset = OpenApiToolset::builder(MINIMAL_SPEC)
            .bearer_token("test-token-123")
            .build()
            .unwrap();
        assert_eq!(toolset.len(), 1);
    }

    #[test]
    fn builder_custom_client() {
        let client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(30))
            .build()
            .unwrap();
        let toolset = OpenApiToolset::builder(MINIMAL_SPEC)
            .client(client)
            .build()
            .unwrap();
        assert_eq!(toolset.len(), 1);
    }

    #[test]
    fn builder_all_options() {
        let toolset = OpenApiToolset::builder(MINIMAL_SPEC)
            .base_url("https://custom.api.com")
            .bearer_token("sk-123")
            .build()
            .unwrap();
        assert_eq!(toolset.len(), 1);
        assert_eq!(toolset.tools[0].base_url, "https://custom.api.com");
    }

    #[test]
    fn base_url_defaults_to_localhost() {
        let spec = r#"
openapi: "3.0.0"
info:
  title: Test
  version: "1.0"
paths:
  /ping:
    get:
      operationId: ping
      summary: Ping
      responses:
        "200":
          description: OK
"#;
        let toolset = OpenApiToolset::from_spec_str(spec).unwrap();
        assert_eq!(toolset.len(), 1);
        assert_eq!(toolset.tools[0].base_url, "http://localhost");
    }

    #[test]
    fn empty_spec_produces_no_tools() {
        let spec = r#"
openapi: "3.0.0"
info:
  title: Test
  version: "1.0"
paths: {}
"#;
        let toolset = OpenApiToolset::from_spec_str(spec).unwrap();
        assert!(toolset.is_empty());
    }

    #[test]
    fn invalid_yaml_returns_error() {
        let result = OpenApiToolset::from_spec_str("not: [valid: yaml: {{");
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn tool_definition_has_correct_fields() {
        let toolset = OpenApiToolset::from_spec_str(MINIMAL_SPEC).unwrap();
        let tools = toolset.into_tools();
        let def = tools[0].definition("".into()).await;

        assert_eq!(def.name, "getUser");
        assert_eq!(def.description, "Get a user by id");
    }

    #[tokio::test]
    async fn tool_definition_path_param_schema() {
        let toolset = OpenApiToolset::from_spec_str(MINIMAL_SPEC).unwrap();
        let tools = toolset.into_tools();
        let def = tools[0].definition("".into()).await;

        let props = def.parameters["properties"].as_object().unwrap();
        assert!(props.contains_key("id"));

        let required = def.parameters["required"].as_array().unwrap();
        assert!(required.contains(&Value::String("id".into())));
    }

    #[tokio::test]
    async fn tool_definition_query_param_not_required() {
        let toolset = OpenApiToolset::from_spec_str(MULTI_METHOD_SPEC).unwrap();
        let tools = toolset.into_tools();
        let list_tool = tools.iter().find(|t| t.name() == "listUsers").unwrap();
        let def = list_tool.definition("".into()).await;

        let props = def.parameters["properties"].as_object().unwrap();
        assert!(props.contains_key("limit"));

        let required = def.parameters["required"].as_array().unwrap();
        assert!(!required.contains(&Value::String("limit".into())));
    }

    #[tokio::test]
    async fn tool_definition_request_body_schema() {
        let toolset = OpenApiToolset::from_spec_str(MULTI_METHOD_SPEC).unwrap();
        let tools = toolset.into_tools();
        let create_tool = tools.iter().find(|t| t.name() == "createUser").unwrap();
        let def = create_tool.definition("".into()).await;

        let props = def.parameters["properties"].as_object().unwrap();
        assert!(props.contains_key("body"));

        let required = def.parameters["required"].as_array().unwrap();
        assert!(required.contains(&Value::String("body".into())));
    }

    #[tokio::test]
    async fn ref_parameters_are_resolved() {
        let toolset = OpenApiToolset::from_spec_str(REF_SPEC).unwrap();
        let tools = toolset.into_tools();
        assert_eq!(tools.len(), 1);

        let def = tools[0].definition("".into()).await;
        let props = def.parameters["properties"].as_object().unwrap();
        assert!(props.contains_key("id"));
    }

    #[tokio::test]
    async fn tool_definition_header_param() {
        let spec = r#"
openapi: "3.0.0"
info:
  title: Test
  version: "1.0"
paths:
  /data:
    get:
      operationId: getData
      summary: Get data
      parameters:
        - name: X-Request-Id
          in: header
          required: false
          schema:
            type: string
          description: Correlation ID
      responses:
        "200":
          description: OK
"#;
        let toolset = OpenApiToolset::from_spec_str(spec).unwrap();
        let tools = toolset.into_tools();
        let def = tools[0].definition("".into()).await;

        let props = def.parameters["properties"].as_object().unwrap();
        assert!(props.contains_key("X-Request-Id"));
    }

    #[tokio::test]
    async fn tool_call_with_invalid_json_returns_error() {
        let toolset = OpenApiToolset::from_spec_str(MINIMAL_SPEC).unwrap();
        let tools = toolset.into_tools();
        let result = tools[0].call("not json".into()).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn hidden_context_excluded_from_schema() {
        let toolset = OpenApiToolset::builder(MINIMAL_SPEC)
            .hidden_context("id", "123")
            .build()
            .unwrap();
        // The builder-level context is stored on every tool for injection at call time.
        assert_eq!(
            toolset.tools[0].hidden_params.get("id").map(String::as_str),
            Some("123")
        );

        let tools = toolset.into_tools();
        let def = tools[0].definition("".into()).await;

        let props = def.parameters["properties"].as_object().unwrap();
        assert!(
            !props.contains_key("id"),
            "hidden param should not appear in schema"
        );

        let required = def.parameters["required"].as_array().unwrap();
        assert!(!required.contains(&Value::String("id".into())));
    }

    #[tokio::test]
    async fn tools_with_context_excludes_from_schema() {
        let toolset = OpenApiToolset::from_spec_str(MINIMAL_SPEC).unwrap();

        // Without context, "id" is visible
        let tools = toolset.tools_with_context(&HashMap::new());
        let def = tools[0].definition("".into()).await;
        let props = def.parameters["properties"].as_object().unwrap();
        assert!(props.contains_key("id"));

        // With context, "id" is hidden
        let ctx = HashMap::from([("id".to_string(), "42".to_string())]);
        let tools = toolset.tools_with_context(&ctx);
        let def = tools[0].definition("".into()).await;
        let props = def.parameters["properties"].as_object().unwrap();
        assert!(!props.contains_key("id"));

        // The shared toolset itself must not pick up per-request context.
        assert!(toolset.tools[0].hidden_params.is_empty());
    }

    #[test]
    fn toolset_reusable_across_contexts() {
        let toolset = OpenApiToolset::from_spec_str(MULTI_METHOD_SPEC).unwrap();

        let ctx1 = HashMap::from([("id".to_string(), "1".to_string())]);
        let ctx2 = HashMap::from([("id".to_string(), "2".to_string())]);

        let tools1 = toolset.tools_with_context(&ctx1);
        let tools2 = toolset.tools_with_context(&ctx2);

        assert_eq!(tools1.len(), 4);
        assert_eq!(tools2.len(), 4);
        assert!(toolset.tools.iter().all(|t| t.hidden_params.is_empty()));
    }

    #[test]
    fn context_preamble_generation() {
        let ctx = HashMap::from([("user_id".to_string(), "123".to_string())]);
        let preamble = OpenApiToolset::context_preamble(&ctx);
        assert!(preamble.contains("user_id = 123"));
        assert!(preamble.contains("Use these values"));
    }

    #[test]
    fn context_preamble_lists_every_entry() {
        let ctx = HashMap::from([
            ("user_id".to_string(), "123".to_string()),
            ("org".to_string(), "acme".to_string()),
        ]);
        let preamble = OpenApiToolset::context_preamble(&ctx);
        assert!(preamble.contains("- user_id = 123"));
        assert!(preamble.contains("- org = acme"));
        assert_eq!(preamble.lines().count(), 3);
    }

    #[test]
    fn context_preamble_empty() {
        let preamble = OpenApiToolset::context_preamble(&HashMap::new());
        assert!(preamble.is_empty());
    }
}