Skip to main content

composio_sdk/utils/
schema.rs

1//! JSON Schema utilities for Composio SDK
2//!
3//! This module provides utilities for working with JSON schemas, including:
4//! - Type mapping between JSON Schema types and Rust types
5//! - Schema validation and conversion
6//! - Default value coercion
7//! - Reserved keyword handling
8//!
9//! # Example
10//!
11//! ```rust
12//! use composio_sdk::utils::schema::{JsonSchemaType, coerce_default_value};
13//! use serde_json::json;
14//!
15//! let schema = json!({
16//!     "type": "boolean",
17//!     "default": "true"
18//! });
19//!
20//! let coerced = coerce_default_value(&json!("true"), &schema);
21//! assert_eq!(coerced, json!(true));
22//! ```
23
24use serde_json::{Value, json};
25use std::collections::{HashMap, HashSet};
26
/// JSON Schema primitive types
///
/// Each variant corresponds to one type name accepted by
/// [`JsonSchemaType::from_str`] and produced by [`JsonSchemaType::as_str`].
/// The enum is `Copy` and `Hash`, so it can be stored directly in sets and maps.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum JsonSchemaType {
    /// The `"string"` schema type.
    String,
    /// The `"integer"` schema type.
    Integer,
    /// The `"number"` schema type.
    Number,
    /// The `"boolean"` schema type.
    Boolean,
    /// The `"array"` schema type.
    Array,
    /// The `"object"` schema type.
    Object,
    /// The `"null"` schema type.
    Null,
}
38
39impl JsonSchemaType {
40    /// Parse a JSON Schema type string
41    pub fn from_str(s: &str) -> Option<Self> {
42        match s {
43            "string" => Some(Self::String),
44            "integer" => Some(Self::Integer),
45            "number" => Some(Self::Number),
46            "boolean" => Some(Self::Boolean),
47            "array" => Some(Self::Array),
48            "object" => Some(Self::Object),
49            "null" => Some(Self::Null),
50            _ => None,
51        }
52    }
53
54    /// Get the string representation
55    pub fn as_str(&self) -> &'static str {
56        match self {
57            Self::String => "string",
58            Self::Integer => "integer",
59            Self::Number => "number",
60            Self::Boolean => "boolean",
61            Self::Array => "array",
62            Self::Object => "object",
63            Self::Null => "null",
64        }
65    }
66
67    /// Check if this is a container type (array or object)
68    pub fn is_container(&self) -> bool {
69        matches!(self, Self::Array | Self::Object)
70    }
71
72    /// Get the Rust type name for this JSON Schema type
73    pub fn rust_type_name(&self) -> &'static str {
74        match self {
75            Self::String => "String",
76            Self::Integer => "i64",
77            Self::Number => "f64",
78            Self::Boolean => "bool",
79            Self::Array => "Vec<Value>",
80            Self::Object => "HashMap<String, Value>",
81            Self::Null => "Option<Value>",
82        }
83    }
84
85    /// Get the default/fallback value for this type
86    pub fn fallback_value(&self) -> Value {
87        match self {
88            Self::String => json!(""),
89            Self::Integer => json!(0),
90            Self::Number => json!(0.0),
91            Self::Boolean => json!(false),
92            Self::Array => json!([]),
93            Self::Object => json!({}),
94            Self::Null => Value::Null,
95        }
96    }
97}
98
/// Rust keywords that cannot be used verbatim as field names.
///
/// Only membership matters to the lookups in this module; the grouping below
/// is purely for readability.
const RUST_KEYWORDS: &[&str] = &[
    // Strict keywords.
    "as", "break", "const", "continue", "crate",
    "else", "enum", "extern", "false", "fn",
    "for", "if", "impl", "in", "let",
    "loop", "match", "mod", "move", "mut",
    "pub", "ref", "return", "self", "Self",
    "static", "struct", "super", "trait", "true",
    "type", "unsafe", "use", "where", "while",
    // Keywords introduced with the 2018 edition.
    "async", "await", "dyn",
    // Reserved for future use.
    "abstract", "become", "box", "do", "final", "macro",
    "override", "priv", "typeof", "unsized", "virtual", "yield",
    "try",
];

/// Field names reserved by Pydantic; they get a different suffix than Rust keywords.
const RESERVED_FIELD_NAMES: &[&str] = &["validate"];

/// Suffix appended to a mapping key to mark it as carrying the serialized
/// keyword mapping of a nested object.
const OBJ_MARKER: &str = "-_object_-";
115
116/// Make a field name safe by appending a suffix
117///
118/// For Rust keywords, we append `_field` to make them valid identifiers.
119/// For Pydantic reserved names, we append `_` for compatibility.
120pub fn make_safe_field_name(name: &str) -> String {
121    if RUST_KEYWORDS.contains(&name) {
122        format!("{}_field", name)
123    } else if RESERVED_FIELD_NAMES.contains(&name) {
124        format!("{}_", name)
125    } else {
126        name.to_string()
127    }
128}
129
130/// Check if a name is a Rust keyword
131pub fn is_rust_keyword(name: &str) -> bool {
132    RUST_KEYWORDS.contains(&name)
133}
134
135/// Check if a name is a reserved field name
136pub fn is_reserved_field_name(name: &str) -> bool {
137    RESERVED_FIELD_NAMES.contains(&name)
138}
139
140/// Substitute reserved keywords in a JSON schema
141///
142/// Returns a tuple of (modified_schema, keyword_mappings) where:
143/// - modified_schema has safe property names
144/// - keyword_mappings maps safe names back to original names
145///
146/// # Example
147///
148/// ```rust
149/// use composio_sdk::utils::schema::substitute_reserved_keywords;
150/// use serde_json::json;
151///
152/// let schema = json!({
153///     "properties": {
154///         "type": {"type": "string"},
155///         "match": {"type": "integer"}
156///     },
157///     "required": ["type"]
158/// });
159///
160/// let (safe_schema, mappings) = substitute_reserved_keywords(&schema);
161/// // safe_schema will have "type_field" and "match_field" as property names
162/// ```
163pub fn substitute_reserved_keywords(schema: &Value) -> (Value, HashMap<String, String>) {
164    let mut mappings = HashMap::new();
165    
166    if !schema.is_object() {
167        return (schema.clone(), mappings);
168    }
169
170    let mut result = schema.clone();
171    
172    if let Some(properties) = result.get("properties").and_then(|p| p.as_object()) {
173        let mut new_properties = serde_json::Map::new();
174        
175        for (prop_name, prop_value) in properties {
176            if is_rust_keyword(prop_name) || is_reserved_field_name(prop_name) {
177                let safe_name = make_safe_field_name(prop_name);
178                
179                // Recursively handle nested objects
180                let safe_value = if prop_value.get("type") == Some(&json!("object")) {
181                    let (nested_schema, nested_map) = substitute_reserved_keywords(prop_value);
182                    if !nested_map.is_empty() {
183                        mappings.insert(format!("{}{}", safe_name, OBJ_MARKER), 
184                                      serde_json::to_string(&nested_map).unwrap_or_default());
185                    }
186                    nested_schema
187                } else {
188                    prop_value.clone()
189                };
190                
191                new_properties.insert(safe_name.clone(), safe_value);
192                mappings.insert(safe_name, prop_name.clone());
193            } else {
194                new_properties.insert(prop_name.clone(), prop_value.clone());
195            }
196        }
197        
198        result["properties"] = Value::Object(new_properties);
199        
200        // Update required array
201        if let Some(required) = result.get("required").and_then(|r| r.as_array()) {
202            let reverse_map: HashMap<_, _> = mappings.iter()
203                .filter(|(k, _)| !k.ends_with(OBJ_MARKER))
204                .map(|(k, v)| (v.as_str(), k.as_str()))
205                .collect();
206            
207            let new_required: Vec<Value> = required.iter()
208                .map(|r| {
209                    if let Some(s) = r.as_str() {
210                        if let Some(safe) = reverse_map.get(s) {
211                            json!(safe)
212                        } else {
213                            r.clone()
214                        }
215                    } else {
216                        r.clone()
217                    }
218                })
219                .collect();
220            
221            result["required"] = json!(new_required);
222        }
223    }
224    
225    (result, mappings)
226}
227
228/// Reinstate reserved keywords in a request object
229///
230/// Reverses the substitution performed by `substitute_reserved_keywords`.
231/// Modifies the request in-place and returns it.
232///
233/// # Example
234///
235/// ```rust
236/// use composio_sdk::utils::schema::reinstate_reserved_keywords;
237/// use serde_json::json;
238/// use std::collections::HashMap;
239///
240/// let mut request = json!({
241///     "type_field": "example",
242///     "match_field": 42
243/// });
244///
245/// let mut mappings = HashMap::new();
246/// mappings.insert("type_field".to_string(), "type".to_string());
247/// mappings.insert("match_field".to_string(), "match".to_string());
248///
249/// let result = reinstate_reserved_keywords(&mut request, &mappings);
250/// // result will have "type" and "match" as keys
251/// ```
252pub fn reinstate_reserved_keywords(
253    request: &mut Value,
254    mappings: &HashMap<String, String>,
255) -> Value {
256    if !request.is_object() {
257        return request.clone();
258    }
259
260    let mut sorted_keys: Vec<_> = mappings.keys().collect();
261    sorted_keys.sort_by(|a, b| b.cmp(a)); // Reverse order
262
263    let obj = request.as_object_mut().unwrap();
264    let mut updates = Vec::new();
265
266    for clean_key in sorted_keys {
267        let mut subkeys = None;
268        let actual_key = if clean_key.ends_with(OBJ_MARKER) {
269            if let Some(nested_json) = mappings.get(clean_key.as_str()) {
270                if let Ok(nested_map) = serde_json::from_str::<HashMap<String, String>>(nested_json) {
271                    subkeys = Some(nested_map);
272                }
273            }
274            clean_key.trim_end_matches(OBJ_MARKER)
275        } else {
276            clean_key.as_str()
277        };
278
279        if let Some(mut value) = obj.remove(actual_key) {
280            if let Some(nested_mappings) = subkeys {
281                value = reinstate_reserved_keywords(&mut value, &nested_mappings);
282            }
283            
284            if let Some(original_key) = mappings.get(clean_key.as_str()) {
285                updates.push((original_key.clone(), value));
286            }
287        }
288    }
289
290    for (key, value) in updates {
291        obj.insert(key, value);
292    }
293
294    Value::Object(obj.clone())
295}
296
297/// Coerce a default value to match the expected type from JSON schema
298///
299/// Handles common mismatches where string defaults should be boolean/int/float.
300/// This fixes issues where API returns stringified defaults like "true" instead of true.
301///
302/// Coercion precedence: boolean > integer > float
303///
304/// # Example
305///
306/// ```rust
307/// use composio_sdk::utils::schema::coerce_default_value;
308/// use serde_json::json;
309///
310/// let schema = json!({"type": "boolean"});
311/// let default = json!("true");
312/// let coerced = coerce_default_value(&default, &schema);
313/// assert_eq!(coerced, json!(true));
314/// ```
315pub fn coerce_default_value(default: &Value, schema: &Value) -> Value {
316    // Only coerce string values
317    let default_str = match default.as_str() {
318        Some(s) => s,
319        None => return default.clone(),
320    };
321
322    // Collect expected types from schema
323    let mut expected_types = HashSet::new();
324
325    // Direct type
326    if let Some(type_str) = schema.get("type").and_then(|t| t.as_str()) {
327        if let Some(schema_type) = JsonSchemaType::from_str(type_str) {
328            expected_types.insert(schema_type);
329        }
330    }
331
332    // Types from combiners (anyOf, oneOf, allOf)
333    for combiner in &["anyOf", "oneOf", "allOf"] {
334        if let Some(options) = schema.get(combiner).and_then(|o| o.as_array()) {
335            for option in options {
336                if let Some(type_str) = option.get("type").and_then(|t| t.as_str()) {
337                    if let Some(schema_type) = JsonSchemaType::from_str(type_str) {
338                        expected_types.insert(schema_type);
339                    }
340                }
341            }
342        }
343    }
344
345    // If string is expected, no coercion needed
346    if expected_types.contains(&JsonSchemaType::String) {
347        return default.clone();
348    }
349
350    // Boolean coercion (takes precedence)
351    if expected_types.contains(&JsonSchemaType::Boolean) {
352        let lower = default_str.to_lowercase();
353        if matches!(lower.as_str(), "true" | "yes" | "1") {
354            return json!(true);
355        }
356        if matches!(lower.as_str(), "false" | "no" | "0") {
357            return json!(false);
358        }
359    }
360
361    // Integer coercion
362    if expected_types.contains(&JsonSchemaType::Integer) {
363        if let Ok(i) = default_str.parse::<i64>() {
364            return json!(i);
365        }
366    }
367
368    // Float coercion
369    if expected_types.contains(&JsonSchemaType::Number) {
370        if let Ok(f) = default_str.parse::<f64>() {
371            return json!(f);
372        }
373    }
374
375    default.clone()
376}
377
378/// Generate a unique request ID
379///
380/// Uses UUID v4 for generating unique identifiers.
381///
382/// # Example
383///
384/// ```rust
385/// use composio_sdk::utils::schema::generate_request_id;
386///
387/// let request_id = generate_request_id();
388/// assert_eq!(request_id.len(), 36); // UUID format
389/// ```
390pub fn generate_request_id() -> String {
391    uuid::Uuid::new_v4().to_string()
392}
393
394/// Generate a UUID v4 string
395///
396/// Alias for `generate_request_id()` for compatibility with Python SDK.
397///
398/// # Example
399///
400/// ```rust
401/// use composio_sdk::utils::schema::generate_uuid;
402///
403/// let uuid = generate_uuid();
404/// assert_eq!(uuid.len(), 36);
405/// ```
406pub fn generate_uuid() -> String {
407    generate_request_id()
408}
409
410/// Generate a short ID (8 characters) from a UUID
411///
412/// Returns the first 8 characters of a UUID with dashes removed.
413/// Useful for generating compact identifiers.
414///
415/// # Example
416///
417/// ```rust
418/// use composio_sdk::utils::schema::generate_short_id;
419///
420/// let short_id = generate_short_id();
421/// assert_eq!(short_id.len(), 8);
422/// assert!(!short_id.contains('-'));
423/// ```
424pub fn generate_short_id() -> String {
425    generate_uuid()
426        .chars()
427        .filter(|c| *c != '-')
428        .take(8)
429        .collect()
430}
431
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that every `(input, expected)` pair holds for
    /// `coerce_default_value` under the given schema.
    fn check_coercions(schema: &Value, cases: &[(Value, Value)]) {
        for (input, expected) in cases {
            assert_eq!(&coerce_default_value(input, schema), expected);
        }
    }

    #[test]
    fn test_json_schema_type_from_str() {
        let cases = [
            ("string", Some(JsonSchemaType::String)),
            ("integer", Some(JsonSchemaType::Integer)),
            ("boolean", Some(JsonSchemaType::Boolean)),
            ("invalid", None),
        ];
        for (name, expected) in cases.iter() {
            assert_eq!(JsonSchemaType::from_str(name), *expected);
        }
    }

    #[test]
    fn test_json_schema_type_as_str() {
        let cases = [
            (JsonSchemaType::String, "string"),
            (JsonSchemaType::Integer, "integer"),
            (JsonSchemaType::Boolean, "boolean"),
        ];
        for (ty, expected) in cases.iter() {
            assert_eq!(ty.as_str(), *expected);
        }
    }

    #[test]
    fn test_json_schema_type_is_container() {
        for ty in [JsonSchemaType::Array, JsonSchemaType::Object].iter() {
            assert!(ty.is_container());
        }
        for ty in [JsonSchemaType::String, JsonSchemaType::Integer].iter() {
            assert!(!ty.is_container());
        }
    }

    #[test]
    fn test_json_schema_type_fallback_value() {
        let cases = [
            (JsonSchemaType::String, json!("")),
            (JsonSchemaType::Integer, json!(0)),
            (JsonSchemaType::Boolean, json!(false)),
            (JsonSchemaType::Array, json!([])),
            (JsonSchemaType::Object, json!({})),
        ];
        for (ty, expected) in cases.iter() {
            assert_eq!(&ty.fallback_value(), expected);
        }
    }

    #[test]
    fn test_is_rust_keyword() {
        for keyword in ["type", "match", "impl"].iter() {
            assert!(is_rust_keyword(keyword));
        }
        for ordinary in ["name", "value"].iter() {
            assert!(!is_rust_keyword(ordinary));
        }
    }

    #[test]
    fn test_make_safe_field_name() {
        let cases = [
            ("type", "type_field"),
            ("match", "match_field"),
            ("validate", "validate_"),
            ("normal", "normal"),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(make_safe_field_name(input), *expected);
        }
    }

    #[test]
    fn test_substitute_reserved_keywords() {
        let schema = json!({
            "properties": {
                "type": {"type": "string"},
                "match": {"type": "integer"},
                "normal": {"type": "boolean"}
            },
            "required": ["type"]
        });

        let (safe_schema, mappings) = substitute_reserved_keywords(&schema);

        // Keyword properties are renamed, ordinary ones are kept.
        let properties = &safe_schema["properties"];
        for expected_key in ["type_field", "match_field", "normal"].iter() {
            assert!(properties.get(*expected_key).is_some());
        }

        // The mapping points from each safe name back to the original.
        for (safe, original) in [("type_field", "type"), ("match_field", "match")].iter() {
            assert_eq!(mappings.get(*safe), Some(&original.to_string()));
        }

        // The required list follows the rename.
        let required = safe_schema["required"].as_array().unwrap();
        assert!(required.contains(&json!("type_field")));
    }

    #[test]
    fn test_reinstate_reserved_keywords() {
        let mut request = json!({
            "type_field": "example",
            "match_field": 42,
            "normal": true
        });

        let mappings: HashMap<String, String> = [("type_field", "type"), ("match_field", "match")]
            .iter()
            .map(|(safe, original)| (safe.to_string(), original.to_string()))
            .collect();

        let result = reinstate_reserved_keywords(&mut request, &mappings);

        // Original names are back, with their values intact.
        let restored = [
            ("type", json!("example")),
            ("match", json!(42)),
            ("normal", json!(true)),
        ];
        for (key, expected) in restored.iter() {
            assert_eq!(result.get(*key), Some(expected));
        }

        // The safe names are gone.
        for stale in ["type_field", "match_field"].iter() {
            assert!(result.get(*stale).is_none());
        }
    }

    #[test]
    fn test_coerce_default_value_boolean() {
        check_coercions(
            &json!({"type": "boolean"}),
            &[
                (json!("true"), json!(true)),
                (json!("false"), json!(false)),
                (json!("yes"), json!(true)),
                (json!("no"), json!(false)),
                (json!("1"), json!(true)),
                (json!("0"), json!(false)),
            ],
        );
    }

    #[test]
    fn test_coerce_default_value_integer() {
        check_coercions(
            &json!({"type": "integer"}),
            &[
                (json!("42"), json!(42)),
                (json!("-10"), json!(-10)),
                (json!("0"), json!(0)),
            ],
        );
    }

    #[test]
    fn test_coerce_default_value_number() {
        check_coercions(
            &json!({"type": "number"}),
            &[
                (json!("3.14"), json!(3.14)),
                (json!("-2.5"), json!(-2.5)),
            ],
        );
    }

    #[test]
    fn test_coerce_default_value_string_no_coercion() {
        check_coercions(
            &json!({"type": "string"}),
            &[
                (json!("true"), json!("true")),
                (json!("42"), json!("42")),
            ],
        );
    }

    #[test]
    fn test_coerce_default_value_with_combiners() {
        let schema = json!({
            "anyOf": [
                {"type": "boolean"},
                {"type": "integer"}
            ]
        });

        // Boolean takes precedence over integer
        check_coercions(
            &schema,
            &[
                (json!("true"), json!(true)),
                (json!("1"), json!(true)),
                (json!("42"), json!(42)),
            ],
        );
    }

    #[test]
    fn test_coerce_default_value_non_string() {
        // Non-string values are returned as-is
        check_coercions(
            &json!({"type": "boolean"}),
            &[
                (json!(true), json!(true)),
                (json!(42), json!(42)),
                (Value::Null, Value::Null),
            ],
        );
    }

    #[test]
    fn test_generate_request_id() {
        let ids = [generate_request_id(), generate_request_id()];

        for id in ids.iter() {
            assert_eq!(id.len(), 36); // UUID v4 format
        }
        assert_ne!(ids[0], ids[1]); // Should be unique
    }

    #[test]
    fn test_generate_uuid() {
        let uuids = [generate_uuid(), generate_uuid()];

        for uuid in uuids.iter() {
            assert_eq!(uuid.len(), 36);
        }
        assert_ne!(uuids[0], uuids[1]);
    }

    #[test]
    fn test_generate_short_id() {
        let short_ids = [generate_short_id(), generate_short_id()];

        for short in short_ids.iter() {
            assert_eq!(short.len(), 8);
            assert!(!short.contains('-'));
        }
        assert_ne!(short_ids[0], short_ids[1]);
    }
}