Skip to main content

post_cortex_daemon/daemon/
coerce.rs

1// Copyright (c) 2025 Julius ML
2// MIT License
3
4//! Type coercion and validation for MCP tool parameters.
5//!
6//! This module provides flexible type coercion to help AI agents interact
7//! with MCP tools correctly, even when they pass parameters with slightly
8//! incorrect types (e.g., numbers instead of strings).
9
10use serde::Deserialize;
11use serde_json::Value;
12
13/// Valid interaction type values accepted by MCP tools.
14pub use crate::daemon::validate::VALID_INTERACTION_TYPES;
15
16/// Coerce and validate a JSON value into the target type.
17///
18/// This function attempts to deserialize the value directly first (fast path).
19/// If that fails, it applies type coercion rules to fix common type mismatches
20/// before attempting deserialization again.
21///
22/// # Type Coercion Rules
23///
24/// - Number → String: Converts integers and floats to their string representation
25/// - Boolean → String: Converts "true"/"false" to strings
26/// - Object/Array → JSON String: Serializes nested structures as JSON strings
27///   (useful for `content` fields that expect stringified JSON)
28///
29/// # Example
30///
31/// ```rust
32/// use serde_json::json;
33///
34/// #[derive(Deserialize)]
35/// struct Request {
36///     session_id: String,
37///     count: String,
38/// }
39///
40/// let value = json!({
41///     "session_id": 123,  // Number instead of string
42///     "count": 42         // Number instead of string
43/// });
44///
45/// let req: Request = coerce_and_validate(value)?;
46/// assert_eq!(req.session_id, "123");
47/// assert_eq!(req.count, "42");
48/// ```
49pub fn coerce_and_validate<T: for<'de> Deserialize<'de>>(
50    value: Value,
51) -> Result<T, CoercionError> {
52    // Fast path: try direct deserialization first
53    if let Ok(result) = serde_json::from_value::<T>(value.clone()) {
54        return Ok(result);
55    }
56
57    // Slow path: apply coercions and try again
58    let coerced = apply_coercions(value)?;
59    serde_json::from_value(coerced).map_err(|e| {
60        CoercionError::new(
61            "Failed to deserialize parameter(s)",
62            e,
63            None,
64        )
65    })
66}
67
68/// Apply type coercion rules to a JSON value.
69///
70/// This recursively walks through the value and applies coercion rules
71/// to convert common type mismatches.
72fn apply_coercions(mut value: Value) -> Result<Value, CoercionError> {
73    if let Some(obj) = value.as_object_mut() {
74        for (_key, val) in obj.iter_mut() {
75            *val = coerce_value(val)?;
76        }
77    }
78    Ok(value)
79}
80
81/// Coerce a single value according to the coercion rules.
82fn coerce_value(val: &Value) -> Result<Value, CoercionError> {
83    match val {
84        // Number → String
85        // Handles cases like: {"session_id": 123} → {"session_id": "123"}
86        Value::Number(n) => {
87            if let Some(i) = n.as_i64() {
88                Ok(Value::String(i.to_string()))
89            } else if let Some(f) = n.as_f64() {
90                Ok(Value::String(f.to_string()))
91            } else {
92                Ok(val.clone())
93            }
94        }
95
96        // Boolean → String
97        // Handles cases like: {"enabled": true} → {"enabled": "true"}
98        Value::Bool(b) => Ok(Value::String(b.to_string())),
99
100        // Object → JSON String
101        // Handles cases where nested objects need to be stringified
102        // e.g., {"metadata": {"key": "value"}} → {"metadata": "{\"key\":\"value\"}"}
103        Value::Object(obj) => {
104            serde_json::to_string(obj)
105                .map(Value::String)
106                .map_err(|e| CoercionError::new(
107                    "Failed to serialize object to JSON string",
108                    e,
109                    Some(val.clone()),
110                ))
111        }
112
113        // Array → JSON String
114        // Similar to object coercion, stringifies arrays
115        Value::Array(arr) => {
116            serde_json::to_string(arr)
117                .map(Value::String)
118                .map_err(|e| CoercionError::new(
119                    "Failed to serialize array to JSON string",
120                    e,
121                    Some(val.clone()),
122                ))
123        }
124
125        // String, Null, etc. pass through unchanged
126        _ => Ok(val.clone()),
127    }
128}
129
130/// Generate recovery suggestions based on error message patterns.
131///
132/// Analyzes common error patterns and provides actionable suggestions
133/// for AI agents to fix their requests.
134///
135/// # Examples
136///
137/// ```rust
138/// let error_msg = "invalid type: string 'abc', expected u32";
139/// let suggestions = generate_recovery_suggestions(error_msg);
140/// // Returns: ["Ensure the value is a valid number", "Check for typos in the value"]
141/// ```
142pub fn generate_recovery_suggestions(
143    error_message: &str,
144    parameter_path: Option<&str>,
145    received_value: Option<&Value>,
146) -> Vec<String> {
147    let mut suggestions = Vec::new();
148
149    // Pattern: UUID validation errors
150    if error_message.contains("UUID") || error_message.contains("36-character") {
151        suggestions.push("Ensure the session_id is a valid 36-character UUID (e.g., '60c598e2-d602-4e07-a328-c458006d48c7')".to_string());
152        suggestions.push("Create a new session using the 'session' tool with action='create' to get a valid UUID".to_string());
153
154        if let Some(Value::String(s)) = received_value {
155            if s.len() != 36 {
156                suggestions.push(format!("Your session_id '{}' has {} characters, but UUIDs require exactly 36 characters with hyphens.", s, s.len()));
157            }
158        }
159    }
160
161    // Pattern: interaction_type validation
162    if error_message.contains("interaction_type") || error_message.contains("Unknown interaction type") {
163        suggestions.push(format!("Valid interaction_type values are: {}", VALID_INTERACTION_TYPES.join(", ")));
164        suggestions.push("Use lowercase with underscores, not CamelCase or spaces".to_string());
165        suggestions.push("Examples: ✅ 'decision_made' ❌ 'DecisionMade' ❌ 'made decision'".to_string());
166    }
167
168    // Pattern: content field errors
169    if error_message.contains("content") && error_message.contains("required") {
170        suggestions.push("For single update mode, provide both 'interaction_type' and 'content' parameters".to_string());
171        suggestions.push("For bulk updates, use 'updates' array instead".to_string());
172        suggestions.push("Content must be an object with key-value pairs".to_string());
173    }
174
175    // Pattern: Type coercion errors
176    if error_message.contains("invalid type") || error_message.contains("expected") {
177        if let Some(path) = parameter_path {
178            suggestions.push(format!("Parameter '{}' has an incorrect type", path));
179        }
180
181        // Check if value is a number when string expected
182        if let Some(Value::Number(n)) = received_value {
183            suggestions.push(format!("Convert the number {} to a string", n));
184        }
185
186        // Check if value is boolean when string expected
187        if let Some(Value::Bool(b)) = received_value {
188            suggestions.push(format!("Convert the boolean {} to a string ('{}')", b, b));
189        }
190    }
191
192    // Pattern: Missing required parameters
193    if error_message.contains("required") || error_message.contains("missing") {
194        suggestions.push("Check that all required parameters are included in your request".to_string());
195        suggestions.push("Review the tool schema to see which parameters are required vs optional".to_string());
196    }
197
198    // Pattern: Session not found
199    if error_message.contains("Session not found") || error_message.contains("session does not exist") {
200        suggestions.push("Create a new session using the 'session' tool with action='create'".to_string());
201        suggestions.push("Or use semantic_search to find existing sessions".to_string());
202    }
203
204    // Pattern: Array/structure errors
205    if error_message.contains("updates") && (error_message.contains("array") || error_message.contains("expected length")) {
206        suggestions.push("When using bulk mode, 'updates' must be an array of update objects".to_string());
207        suggestions.push("Each update in the array must have 'interaction_type' and 'content' fields".to_string());
208    }
209
210    // General fallback suggestions
211    if suggestions.is_empty() {
212        suggestions.push("Review the error message and check your parameter types and values".to_string());
213        suggestions.push("Use dry_run=true to validate your request without making changes".to_string());
214        suggestions.push("Check the tool documentation for the correct parameter format".to_string());
215    }
216
217    suggestions
218}
219
220/// Structured error type for coercion failures.
221///
222/// Provides rich error information to help AI agents understand
223/// what went wrong and how to fix it.
224#[derive(Debug, Clone)]
225pub struct CoercionError {
226    /// Human-readable error message
227    pub message: String,
228    /// Path to the parameter that failed (e.g., "session_id", "updates\[0\].interaction_type")
229    pub parameter_path: Option<String>,
230    /// Expected type description (e.g., "UUID string", "one of: qa, decision_made, ...")
231    pub expected_type: Option<String>,
232    /// The actual value that was received
233    pub received_value: Option<Value>,
234    /// Actionable hint for fixing the error
235    pub hint: Option<String>,
236}
237
238impl CoercionError {
239    /// Create a new coercion error.
240    pub fn new(
241        message: &str,
242        source_error: impl std::error::Error,
243        received_value: Option<Value>,
244    ) -> Self {
245        Self {
246            message: format!("{}: {}", message, source_error),
247            parameter_path: None,
248            expected_type: None,
249            received_value,
250            hint: None,
251        }
252    }
253
254    /// Set the parameter path for this error.
255    pub fn with_parameter_path(mut self, path: String) -> Self {
256        self.parameter_path = Some(path);
257        self
258    }
259
260    /// Set the expected type description for this error.
261    pub fn with_expected_type(mut self, type_desc: &str) -> Self {
262        self.expected_type = Some(type_desc.to_string());
263        self
264    }
265
266    /// Set a hint for fixing this error.
267    pub fn with_hint(mut self, hint: &str) -> Self {
268        self.hint = Some(hint.to_string());
269        self
270    }
271
272    /// Convert this error to an MCP error response.
273    ///
274    /// Creates a structured JSON error that agents can parse and understand.
275    /// Includes automatically generated recovery suggestions.
276    pub fn to_mcp_error(&self) -> rmcp::model::ErrorData {
277        let mut details = serde_json::json!({
278            "message": self.message,
279        });
280
281        if let Some(path) = &self.parameter_path {
282            details["parameter"] = serde_json::json!(path);
283        }
284
285        if let Some(expected) = &self.expected_type {
286            details["expectedType"] = serde_json::json!(expected);
287        }
288
289        if let Some(received) = &self.received_value {
290            details["receivedValue"] = received.clone();
291        }
292
293        if let Some(hint) = &self.hint {
294            details["hint"] = serde_json::json!(hint);
295        }
296
297        // Generate and include recovery suggestions
298        let suggestions = generate_recovery_suggestions(
299            &self.message,
300            self.parameter_path.as_deref(),
301            self.received_value.as_ref(),
302        );
303
304        if !suggestions.is_empty() {
305            details["suggestions"] = serde_json::json!(suggestions);
306        }
307
308        rmcp::model::ErrorData::invalid_params(
309            serde_json::to_string(&details).unwrap_or_default(),
310            None,
311        )
312    }
313}
314
315impl std::fmt::Display for CoercionError {
316    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
317        write!(f, "{}", self.message)?;
318        if let Some(path) = &self.parameter_path {
319            write!(f, " (parameter: {})", path)?;
320        }
321        if let Some(hint) = &self.hint {
322            write!(f, "\nHint: {}", hint)?;
323        }
324        Ok(())
325    }
326}
327
328impl std::error::Error for CoercionError {}
329
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::collections::HashMap;

    /// Coerce `value` into a flat string map, panicking if coercion fails.
    fn coerce_to_string_map(value: serde_json::Value) -> HashMap<String, String> {
        coerce_and_validate(value).expect("coercion into HashMap<String, String> should succeed")
    }

    /// Build an `InvalidInput` I/O error to act as a source error in fixtures.
    fn invalid_input(detail: &str) -> std::io::Error {
        std::io::Error::new(std::io::ErrorKind::InvalidInput, detail)
    }

    /// True when at least one suggestion contains `needle`.
    fn mentions(suggestions: &[String], needle: &str) -> bool {
        suggestions.iter().any(|suggestion| suggestion.contains(needle))
    }

    #[test]
    fn test_coerce_number_to_string() {
        let map = coerce_to_string_map(json!({"session_id": 123}));
        assert_eq!(map.get("session_id").map(String::as_str), Some("123"));
    }

    #[test]
    fn test_coerce_float_to_string() {
        // Use an arbitrary float that doesn't approximate `f64::consts::PI`.
        let map = coerce_to_string_map(json!({"score": 2.71}));
        assert_eq!(map.get("score").map(String::as_str), Some("2.71"));
    }

    #[test]
    fn test_coerce_bool_to_string() {
        let map = coerce_to_string_map(json!({"enabled": true}));
        assert_eq!(map.get("enabled").map(String::as_str), Some("true"));
    }

    #[test]
    fn test_coerce_object_to_json_string() {
        let map = coerce_to_string_map(json!({"metadata": {"key": "value"}}));
        assert_eq!(map.get("metadata").map(String::as_str), Some("{\"key\":\"value\"}"));
    }

    #[test]
    fn test_coerce_array_to_json_string() {
        let map = coerce_to_string_map(json!({"tags": ["tag1", "tag2"]}));
        assert_eq!(map.get("tags").map(String::as_str), Some("[\"tag1\",\"tag2\"]"));
    }

    #[test]
    fn test_fast_path_string_passes_through() {
        let map = coerce_to_string_map(json!({"name": "test"}));
        assert_eq!(map.get("name").map(String::as_str), Some("test"));
    }

    #[test]
    fn test_coercion_error_with_path() {
        let error = CoercionError::new("Test error", invalid_input("invalid"), Some(json!(123)))
            .with_parameter_path("session_id".to_string())
            .with_expected_type("UUID string")
            .with_hint("Create a session first");

        assert_eq!(error.parameter_path.as_deref(), Some("session_id"));
        assert_eq!(error.expected_type.as_deref(), Some("UUID string"));
        assert_eq!(error.hint.as_deref(), Some("Create a session first"));
    }

    #[test]
    fn test_to_mcp_error_format() {
        let error = CoercionError::new("Invalid parameter", invalid_input("test"), Some(json!(123)))
            .with_parameter_path("session_id".to_string())
            .with_expected_type("UUID string")
            .with_hint("Use session tool to create");

        // to_mcp_error() packs the structured details into the message as JSON
        // (its `data` argument is None), so the message itself is parsed here.
        let message = error.to_mcp_error().message;
        assert!(message.contains("session_id"));

        let parsed: serde_json::Value = serde_json::from_str(&message).unwrap();
        assert_eq!(parsed["parameter"], "session_id");
        assert_eq!(parsed["expectedType"], "UUID string");
        assert_eq!(parsed["receivedValue"], 123);
        assert_eq!(parsed["hint"], "Use session tool to create");
    }

    #[test]
    fn test_recovery_suggestions_uuid_error() {
        let suggestions =
            generate_recovery_suggestions("Invalid UUID format", Some("session_id"), Some(&json!("abc")));

        assert!(mentions(&suggestions, "36-character UUID"));
        assert!(mentions(&suggestions, "'session' tool"));
    }

    #[test]
    fn test_recovery_suggestions_interaction_type_error() {
        let suggestions = generate_recovery_suggestions(
            "Unknown interaction type",
            Some("interaction_type"),
            Some(&json!("made_decision")),
        );

        assert!(mentions(&suggestions, "decision_made"));
        assert!(mentions(&suggestions, "lowercase with underscores"));
    }

    #[test]
    fn test_recovery_suggestions_type_error() {
        let suggestions = generate_recovery_suggestions(
            "invalid type: integer `123`, expected a string",
            Some("session_id"),
            Some(&json!(123)),
        );

        assert!(mentions(&suggestions, "Convert the number 123"));
    }

    #[test]
    fn test_recovery_suggestions_content_required() {
        let suggestions = generate_recovery_suggestions("content is required", Some("content"), None);

        assert!(mentions(&suggestions, "interaction_type"));
        assert!(mentions(&suggestions, "bulk updates"));
    }

    #[test]
    fn test_recovery_suggestions_session_not_found() {
        let suggestions = generate_recovery_suggestions("Session not found", None, None);

        assert!(mentions(&suggestions, "'session' tool"));
        assert!(mentions(&suggestions, "semantic_search"));
    }

    #[test]
    fn test_recovery_suggestions_includes_suggestions_in_mcp_error() {
        let error = CoercionError::new(
            "Invalid UUID format",
            invalid_input("invalid"),
            Some(json!("short-id")),
        )
        .with_parameter_path("session_id".to_string());

        let parsed: serde_json::Value =
            serde_json::from_str(&error.to_mcp_error().message).unwrap();

        // The suggestions array must be present and non-empty.
        let suggestions = parsed["suggestions"]
            .as_array()
            .expect("suggestions should be a JSON array");
        assert!(!suggestions.is_empty());
    }

    #[test]
    fn test_recovery_suggestions_general_fallback() {
        let suggestions = generate_recovery_suggestions("Some unknown error", None, None);

        assert!(mentions(&suggestions, "dry_run"));
        assert!(mentions(&suggestions, "parameter types"));
    }
}