Skip to main content

post_cortex_daemon/daemon/
coerce.rs

1// Copyright (c) 2025 Julius ML
2// MIT License
3
4//! Type coercion and validation for MCP tool parameters.
5//!
6//! This module provides flexible type coercion to help AI agents interact
7//! with MCP tools correctly, even when they pass parameters with slightly
8//! incorrect types (e.g., numbers instead of strings).
9
10use serde::Deserialize;
11use serde_json::Value;
12
13/// Valid interaction type values accepted by MCP tools.
14pub use crate::daemon::validate::VALID_INTERACTION_TYPES;
15
16/// Coerce and validate a JSON value into the target type.
17///
18/// This function attempts to deserialize the value directly first (fast path).
19/// If that fails, it applies type coercion rules to fix common type mismatches
20/// before attempting deserialization again.
21///
22/// # Type Coercion Rules
23///
24/// - Number → String: Converts integers and floats to their string representation
25/// - Boolean → String: Converts "true"/"false" to strings
26/// - Object/Array → JSON String: Serializes nested structures as JSON strings
27///   (useful for `content` fields that expect stringified JSON)
28///
29/// # Example
30///
31/// ```rust,ignore
32/// use serde_json::json;
33///
34/// #[derive(Deserialize)]
35/// struct Request {
36///     session_id: String,
37///     count: String,
38/// }
39///
40/// let value = json!({
41///     "session_id": 123,  // Number instead of string
42///     "count": 42         // Number instead of string
43/// });
44///
45/// let req: Request = coerce_and_validate(value)?;
46/// assert_eq!(req.session_id, "123");
47/// assert_eq!(req.count, "42");
48/// ```
49pub fn coerce_and_validate<T: for<'de> Deserialize<'de>>(value: Value) -> Result<T, CoercionError> {
50    // Fast path: try direct deserialization first
51    if let Ok(result) = serde_json::from_value::<T>(value.clone()) {
52        return Ok(result);
53    }
54
55    // Slow path: apply coercions and try again
56    let coerced = apply_coercions(value)?;
57    serde_json::from_value(coerced)
58        .map_err(|e| CoercionError::new("Failed to deserialize parameter(s)", e, None))
59}
60
61/// Apply type coercion rules to a JSON value.
62///
63/// This recursively walks through the value and applies coercion rules
64/// to convert common type mismatches.
65fn apply_coercions(mut value: Value) -> Result<Value, CoercionError> {
66    if let Some(obj) = value.as_object_mut() {
67        for (_key, val) in obj.iter_mut() {
68            *val = coerce_value(val)?;
69        }
70    }
71    Ok(value)
72}
73
74/// Coerce a single value according to the coercion rules.
75fn coerce_value(val: &Value) -> Result<Value, CoercionError> {
76    match val {
77        // Number → String
78        // Handles cases like: {"session_id": 123} → {"session_id": "123"}
79        Value::Number(n) => {
80            if let Some(i) = n.as_i64() {
81                Ok(Value::String(i.to_string()))
82            } else if let Some(f) = n.as_f64() {
83                Ok(Value::String(f.to_string()))
84            } else {
85                Ok(val.clone())
86            }
87        }
88
89        // Boolean → String
90        // Handles cases like: {"enabled": true} → {"enabled": "true"}
91        Value::Bool(b) => Ok(Value::String(b.to_string())),
92
93        // Object → JSON String
94        // Handles cases where nested objects need to be stringified
95        // e.g., {"metadata": {"key": "value"}} → {"metadata": "{\"key\":\"value\"}"}
96        Value::Object(obj) => serde_json::to_string(obj).map(Value::String).map_err(|e| {
97            CoercionError::new(
98                "Failed to serialize object to JSON string",
99                e,
100                Some(val.clone()),
101            )
102        }),
103
104        // Array → JSON String
105        // Similar to object coercion, stringifies arrays
106        Value::Array(arr) => serde_json::to_string(arr).map(Value::String).map_err(|e| {
107            CoercionError::new(
108                "Failed to serialize array to JSON string",
109                e,
110                Some(val.clone()),
111            )
112        }),
113
114        // String, Null, etc. pass through unchanged
115        _ => Ok(val.clone()),
116    }
117}
118
119/// Generate recovery suggestions based on error message patterns.
120///
121/// Analyzes common error patterns and provides actionable suggestions
122/// for AI agents to fix their requests.
123///
124/// # Examples
125///
126/// ```rust,ignore
127/// let error_msg = "invalid type: string 'abc', expected u32";
128/// let suggestions = generate_recovery_suggestions(error_msg);
129/// // Returns: ["Ensure the value is a valid number", "Check for typos in the value"]
130/// ```
131pub fn generate_recovery_suggestions(
132    error_message: &str,
133    parameter_path: Option<&str>,
134    received_value: Option<&Value>,
135) -> Vec<String> {
136    let mut suggestions = Vec::new();
137
138    // Pattern: UUID validation errors
139    if error_message.contains("UUID") || error_message.contains("36-character") {
140        suggestions.push("Ensure the session_id is a valid 36-character UUID (e.g., '60c598e2-d602-4e07-a328-c458006d48c7')".to_string());
141        suggestions.push("Create a new session using the 'session' tool with action='create' to get a valid UUID".to_string());
142
143        if let Some(Value::String(s)) = received_value
144            && s.len() != 36
145        {
146            suggestions.push(format!("Your session_id '{}' has {} characters, but UUIDs require exactly 36 characters with hyphens.", s, s.len()));
147        }
148    }
149
150    // Pattern: interaction_type validation
151    if error_message.contains("interaction_type")
152        || error_message.contains("Unknown interaction type")
153    {
154        suggestions.push(format!(
155            "Valid interaction_type values are: {}",
156            VALID_INTERACTION_TYPES.join(", ")
157        ));
158        suggestions.push("Use lowercase with underscores, not CamelCase or spaces".to_string());
159        suggestions
160            .push("Examples: ✅ 'decision_made' ❌ 'DecisionMade' ❌ 'made decision'".to_string());
161    }
162
163    // Pattern: content field errors
164    if error_message.contains("content") && error_message.contains("required") {
165        suggestions.push(
166            "For single update mode, provide both 'interaction_type' and 'content' parameters"
167                .to_string(),
168        );
169        suggestions.push("For bulk updates, use 'updates' array instead".to_string());
170        suggestions.push("Content must be an object with key-value pairs".to_string());
171    }
172
173    // Pattern: Type coercion errors
174    if error_message.contains("invalid type") || error_message.contains("expected") {
175        if let Some(path) = parameter_path {
176            suggestions.push(format!("Parameter '{}' has an incorrect type", path));
177        }
178
179        // Check if value is a number when string expected
180        if let Some(Value::Number(n)) = received_value {
181            suggestions.push(format!("Convert the number {} to a string", n));
182        }
183
184        // Check if value is boolean when string expected
185        if let Some(Value::Bool(b)) = received_value {
186            suggestions.push(format!("Convert the boolean {} to a string ('{}')", b, b));
187        }
188    }
189
190    // Pattern: Missing required parameters
191    if error_message.contains("required") || error_message.contains("missing") {
192        suggestions
193            .push("Check that all required parameters are included in your request".to_string());
194        suggestions.push(
195            "Review the tool schema to see which parameters are required vs optional".to_string(),
196        );
197    }
198
199    // Pattern: Session not found
200    if error_message.contains("Session not found")
201        || error_message.contains("session does not exist")
202    {
203        suggestions
204            .push("Create a new session using the 'session' tool with action='create'".to_string());
205        suggestions.push("Or use semantic_search to find existing sessions".to_string());
206    }
207
208    // Pattern: Array/structure errors
209    if error_message.contains("updates")
210        && (error_message.contains("array") || error_message.contains("expected length"))
211    {
212        suggestions
213            .push("When using bulk mode, 'updates' must be an array of update objects".to_string());
214        suggestions.push(
215            "Each update in the array must have 'interaction_type' and 'content' fields"
216                .to_string(),
217        );
218    }
219
220    // General fallback suggestions
221    if suggestions.is_empty() {
222        suggestions
223            .push("Review the error message and check your parameter types and values".to_string());
224        suggestions
225            .push("Use dry_run=true to validate your request without making changes".to_string());
226        suggestions
227            .push("Check the tool documentation for the correct parameter format".to_string());
228    }
229
230    suggestions
231}
232
233/// Structured error type for coercion failures.
234///
235/// Provides rich error information to help AI agents understand
236/// what went wrong and how to fix it.
237#[derive(Debug, Clone)]
238pub struct CoercionError {
239    /// Human-readable error message
240    pub message: String,
241    /// Path to the parameter that failed (e.g., "session_id", "updates\[0\].interaction_type")
242    pub parameter_path: Option<String>,
243    /// Expected type description (e.g., "UUID string", "one of: qa, decision_made, ...")
244    pub expected_type: Option<String>,
245    /// The actual value that was received
246    pub received_value: Option<Value>,
247    /// Actionable hint for fixing the error
248    pub hint: Option<String>,
249}
250
251impl CoercionError {
252    /// Create a new coercion error.
253    pub fn new(
254        message: &str,
255        source_error: impl std::error::Error,
256        received_value: Option<Value>,
257    ) -> Self {
258        Self {
259            message: format!("{}: {}", message, source_error),
260            parameter_path: None,
261            expected_type: None,
262            received_value,
263            hint: None,
264        }
265    }
266
267    /// Set the parameter path for this error.
268    pub fn with_parameter_path(mut self, path: String) -> Self {
269        self.parameter_path = Some(path);
270        self
271    }
272
273    /// Set the expected type description for this error.
274    pub fn with_expected_type(mut self, type_desc: &str) -> Self {
275        self.expected_type = Some(type_desc.to_string());
276        self
277    }
278
279    /// Set a hint for fixing this error.
280    pub fn with_hint(mut self, hint: &str) -> Self {
281        self.hint = Some(hint.to_string());
282        self
283    }
284
285    /// Convert this error to an MCP error response.
286    ///
287    /// Creates a structured JSON error that agents can parse and understand.
288    /// Includes automatically generated recovery suggestions.
289    pub fn to_mcp_error(&self) -> rmcp::model::ErrorData {
290        let mut details = serde_json::json!({
291            "message": self.message,
292        });
293
294        if let Some(path) = &self.parameter_path {
295            details["parameter"] = serde_json::json!(path);
296        }
297
298        if let Some(expected) = &self.expected_type {
299            details["expectedType"] = serde_json::json!(expected);
300        }
301
302        if let Some(received) = &self.received_value {
303            details["receivedValue"] = received.clone();
304        }
305
306        if let Some(hint) = &self.hint {
307            details["hint"] = serde_json::json!(hint);
308        }
309
310        // Generate and include recovery suggestions
311        let suggestions = generate_recovery_suggestions(
312            &self.message,
313            self.parameter_path.as_deref(),
314            self.received_value.as_ref(),
315        );
316
317        if !suggestions.is_empty() {
318            details["suggestions"] = serde_json::json!(suggestions);
319        }
320
321        rmcp::model::ErrorData::invalid_params(
322            serde_json::to_string(&details).unwrap_or_default(),
323            None,
324        )
325    }
326}
327
328impl std::fmt::Display for CoercionError {
329    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
330        write!(f, "{}", self.message)?;
331        if let Some(path) = &self.parameter_path {
332            write!(f, " (parameter: {})", path)?;
333        }
334        if let Some(hint) = &self.hint {
335            write!(f, "\nHint: {}", hint)?;
336        }
337        Ok(())
338    }
339}
340
341impl std::error::Error for CoercionError {}
342
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::collections::HashMap;

    /// Runs `coerce_and_validate` with a string-to-string map as the target,
    /// panicking if coercion fails.
    fn coerce_to_string_map(input: serde_json::Value) -> HashMap<String, String> {
        coerce_and_validate(input).unwrap()
    }

    /// Reports whether any suggestion contains `needle`.
    fn mentions(suggestions: &[String], needle: &str) -> bool {
        suggestions.iter().any(|s| s.contains(needle))
    }

    /// Builds a `CoercionError` whose source is an `InvalidInput` I/O error.
    fn io_backed_error(
        message: &str,
        source_text: &str,
        received: serde_json::Value,
    ) -> CoercionError {
        CoercionError::new(
            message,
            std::io::Error::new(std::io::ErrorKind::InvalidInput, source_text),
            Some(received),
        )
    }

    #[test]
    fn test_coerce_number_to_string() {
        let map = coerce_to_string_map(json!({"session_id": 123}));
        assert_eq!(map.get("session_id").map(String::as_str), Some("123"));
    }

    #[test]
    fn test_coerce_float_to_string() {
        // Use an arbitrary float that doesn't approximate `f64::consts::PI`.
        let map = coerce_to_string_map(json!({"score": 2.71}));
        assert_eq!(map.get("score").map(String::as_str), Some("2.71"));
    }

    #[test]
    fn test_coerce_bool_to_string() {
        let map = coerce_to_string_map(json!({"enabled": true}));
        assert_eq!(map.get("enabled").map(String::as_str), Some("true"));
    }

    #[test]
    fn test_coerce_object_to_json_string() {
        let map = coerce_to_string_map(json!({"metadata": {"key": "value"}}));
        assert_eq!(
            map.get("metadata").map(String::as_str),
            Some(r#"{"key":"value"}"#)
        );
    }

    #[test]
    fn test_coerce_array_to_json_string() {
        let map = coerce_to_string_map(json!({"tags": ["tag1", "tag2"]}));
        assert_eq!(
            map.get("tags").map(String::as_str),
            Some(r#"["tag1","tag2"]"#)
        );
    }

    #[test]
    fn test_fast_path_string_passes_through() {
        let map = coerce_to_string_map(json!({"name": "test"}));
        assert_eq!(map.get("name").map(String::as_str), Some("test"));
    }

    #[test]
    fn test_coercion_error_with_path() {
        let error = io_backed_error("Test error", "invalid", json!(123))
            .with_parameter_path("session_id".to_string())
            .with_expected_type("UUID string")
            .with_hint("Create a session first");

        assert_eq!(error.parameter_path.as_deref(), Some("session_id"));
        assert_eq!(error.expected_type.as_deref(), Some("UUID string"));
        assert_eq!(error.hint.as_deref(), Some("Create a session first"));
    }

    #[test]
    fn test_to_mcp_error_format() {
        let error = io_backed_error("Invalid parameter", "test", json!(123))
            .with_parameter_path("session_id".to_string())
            .with_expected_type("UUID string")
            .with_hint("Use session tool to create");

        // `to_mcp_error()` serializes the details into the message and passes
        // `None` as the data argument, so the structure is checked through
        // the message text.
        let mcp_error = error.to_mcp_error();
        let error_message = mcp_error.message;
        assert!(error_message.contains("session_id"));

        // The message must parse as JSON with the expected fields.
        let error_data: serde_json::Value = serde_json::from_str(&error_message).unwrap();
        assert_eq!(error_data["parameter"], "session_id");
        assert_eq!(error_data["expectedType"], "UUID string");
        assert_eq!(error_data["receivedValue"], 123);
        assert_eq!(error_data["hint"], "Use session tool to create");
    }

    #[test]
    fn test_recovery_suggestions_uuid_error() {
        let received = json!("abc");
        let suggestions = generate_recovery_suggestions(
            "Invalid UUID format",
            Some("session_id"),
            Some(&received),
        );

        assert!(mentions(&suggestions, "36-character UUID"));
        assert!(mentions(&suggestions, "'session' tool"));
    }

    #[test]
    fn test_recovery_suggestions_interaction_type_error() {
        let received = json!("made_decision");
        let suggestions = generate_recovery_suggestions(
            "Unknown interaction type",
            Some("interaction_type"),
            Some(&received),
        );

        assert!(mentions(&suggestions, "decision_made"));
        assert!(mentions(&suggestions, "lowercase with underscores"));
    }

    #[test]
    fn test_recovery_suggestions_type_error() {
        let received = json!(123);
        let suggestions = generate_recovery_suggestions(
            "invalid type: integer `123`, expected a string",
            Some("session_id"),
            Some(&received),
        );

        assert!(mentions(&suggestions, "Convert the number 123"));
    }

    #[test]
    fn test_recovery_suggestions_content_required() {
        let suggestions =
            generate_recovery_suggestions("content is required", Some("content"), None);

        assert!(mentions(&suggestions, "interaction_type"));
        assert!(mentions(&suggestions, "bulk updates"));
    }

    #[test]
    fn test_recovery_suggestions_session_not_found() {
        let suggestions = generate_recovery_suggestions("Session not found", None, None);

        assert!(mentions(&suggestions, "'session' tool"));
        assert!(mentions(&suggestions, "semantic_search"));
    }

    #[test]
    fn test_recovery_suggestions_includes_suggestions_in_mcp_error() {
        let error = io_backed_error("Invalid UUID format", "invalid", json!("short-id"))
            .with_parameter_path("session_id".to_string());

        let mcp_error = error.to_mcp_error();
        let error_data: serde_json::Value = serde_json::from_str(&mcp_error.message).unwrap();

        // The suggestions key must hold a non-empty array.
        let suggestions_field = &error_data["suggestions"];
        assert!(suggestions_field.is_array());
        assert!(!suggestions_field.as_array().unwrap().is_empty());
    }

    #[test]
    fn test_recovery_suggestions_general_fallback() {
        let suggestions = generate_recovery_suggestions("Some unknown error", None, None);

        assert!(mentions(&suggestions, "dry_run"));
        assert!(mentions(&suggestions, "parameter types"));
    }
}