Skip to main content

rab/agent/
extension.rs

1/// Extension trait - all capability (built-in or user-provided) comes through this.
2use crate::tui::Theme;
3use std::borrow::Cow;
4use std::sync::{
5    Arc,
6    atomic::{AtomicBool, Ordering},
7};
8
9// ── Tool call hooks (matching pi's beforeToolCall / afterToolCall) ──
10
/// Result returned from `before_tool_call` (matching pi's `BeforeToolCallResult`).
///
/// Returning a value with `block: true` prevents execution; `reason` becomes the
/// error text. `ToolDefinition::execute` in this file consults the hook after
/// argument coercion and validation, and before the wrapped tool runs.
pub struct BeforeToolCallResult {
    /// If true, the tool execution is blocked.
    pub block: bool,
    /// Error message shown when `block` is true. If empty, a default message
    /// of the form `Tool <name> execution blocked` is used.
    pub reason: String,
}
19
/// Partial override returned from `after_tool_call` (matching pi's `AfterToolCallResult`).
///
/// Merge semantics are field-by-field: provided (`Some`) fields replace the
/// original; omitted (`None`) fields keep their values. There is no deep merge —
/// each provided field replaces the corresponding value wholesale.
pub struct AfterToolCallResult {
    /// If provided, replaces the tool result content array in full.
    pub content: Option<Vec<yoagent::types::Content>>,
    /// If provided, replaces the tool result details value in full.
    pub details: Option<serde_json::Value>,
    /// If provided, replaces the tool result error flag
    /// (so a hook can turn a failure into a success or vice versa).
    pub is_error: Option<bool>,
}
30
/// Result returned from `before_compact` (matching pi's `SessionBeforeCompactResult`).
///
/// Returning a value with `cancel: true` prevents compaction. Returning `None`
/// from the hook instead lets the default compaction proceed (see
/// `Extension::before_compact`).
pub struct BeforeCompactResult {
    /// If true, compaction is cancelled entirely.
    pub cancel: bool,
    /// If provided, uses this summary instead of calling the provider.
    pub summary: Option<String>,
    /// Optional details stored with the compaction entry.
    pub details: Option<serde_json::Value>,
}
41
/// A tool bundled with its prompt metadata.
///
/// Mirrors pi's `ToolDefinition` which carries `promptSnippet`,
/// `promptGuidelines` and `prepareArguments` directly on the tool definition.
///
/// `ToolDefinition` itself implements `yoagent::types::AgentTool` in this file:
/// `name`, `label`, `description` and `parameters_schema` are forwarded to
/// `tool`, while `execute` wraps the inner call with `prepare_arguments`,
/// schema coercion, schema validation and the before/after hooks below.
pub struct ToolDefinition {
    /// The wrapped tool that performs the actual work.
    pub tool: Box<dyn yoagent::types::AgentTool>,
    /// One-line snippet for the "Available tools" section of the system prompt.
    pub snippet: &'static str,
    /// Guideline bullets for the "Guidelines" section of the system prompt.
    pub guidelines: &'static [&'static str],
    /// Optional pre-processing of raw LLM arguments before execute().
    /// Receives raw arguments, returns normalized arguments or an error message.
    /// An `Err` is surfaced as `ToolError::InvalidArgs` and stops the call.
    pub prepare_arguments: Option<fn(serde_json::Value) -> Result<serde_json::Value, String>>,
    /// Called before tool execution, after argument validation (matching pi's `beforeToolCall`).
    /// Return `Some(BeforeToolCallResult { block: true, reason: "..." })` to block execution.
    /// Returning `None`, or a result with `block: false`, lets the call proceed.
    pub before_tool_call: Option<fn(&serde_json::Value) -> Option<BeforeToolCallResult>>,
    /// Called after tool execution, before the result is returned (matching pi's `afterToolCall`).
    /// The `bool` argument is the error flag of the execution that just finished.
    pub after_tool_call:
        Option<fn(&yoagent::types::ToolResult, bool) -> Option<AfterToolCallResult>>,
    /// Tool-specific renderer for the TUI, bundled with the tool definition
    /// (pi's renderCall/renderResult live on ToolDefinition).
    pub renderer: Option<Arc<dyn ToolRenderer>>,
}
65
66// ── Generic argument type coercion & validation ─────────────────
67
68/// Coerce a single JSON value to match a JSON Schema type (modifies in place).
69/// This handles common LLM mistakes: sending numbers as strings, booleans as strings, etc.
70pub fn coerce_primitive_by_type(schema_type: &str, value: &mut serde_json::Value) {
71    match schema_type {
72        "string" => {
73            if value.is_number() || value.is_boolean() {
74                *value = serde_json::Value::String(match value {
75                    serde_json::Value::Number(n) => n.to_string(),
76                    serde_json::Value::Bool(b) => b.to_string(),
77                    _ => unreachable!(),
78                });
79            } else if value.is_null() {
80                *value = serde_json::Value::String(String::new());
81            } else if value.is_array() || value.is_object() {
82                // TypeBox's Value.Convert stringifies arrays/objects when schema expects string
83                *value =
84                    serde_json::Value::String(serde_json::to_string(value).unwrap_or_default());
85            }
86        }
87        "number" => {
88            if let Some(s) = value.as_str() {
89                if let Ok(n) = s.parse::<f64>() {
90                    *value = serde_json::json!(n);
91                }
92            } else if value.is_boolean() {
93                *value = serde_json::json!(if value.as_bool().unwrap() { 1.0 } else { 0.0 });
94            } else if value.is_null() {
95                *value = serde_json::json!(0.0);
96            }
97        }
98        "integer" => {
99            if let Some(s) = value.as_str() {
100                if let Ok(n) = s.parse::<f64>() {
101                    *value = serde_json::json!(n as i64);
102                }
103            } else if value.is_boolean() {
104                *value = serde_json::json!(if value.as_bool().unwrap() { 1i64 } else { 0i64 });
105            } else if value.is_null() {
106                *value = serde_json::json!(0i64);
107            } else if let Some(n) = value.as_f64() {
108                *value = serde_json::json!(n as i64);
109            }
110        }
111        "boolean" => {
112            if let Some(s) = value.as_str() {
113                match s.trim().to_lowercase().as_str() {
114                    "true" | "1" | "yes" | "on" => *value = serde_json::Value::Bool(true),
115                    "false" | "0" | "no" | "off" => *value = serde_json::Value::Bool(false),
116                    _ => {} // Leave as-is if unrecognized
117                }
118            } else if value.is_number() {
119                *value = serde_json::Value::Bool(value.as_f64().unwrap_or(0.0) != 0.0);
120            } else if value.is_null() {
121                *value = serde_json::Value::Bool(false);
122            }
123        }
124        "null" => {
125            // Pi-compatible: treat empty string, 0, or false as null
126            if value.as_str().is_some_and(|s| s.is_empty())
127                || value.as_f64() == Some(0.0)
128                || value.as_bool() == Some(false)
129            {
130                *value = serde_json::Value::Null;
131            }
132        }
133        "array" => {
134            if !value.is_array() && !value.is_null() {
135                let v = std::mem::take(value);
136                *value = serde_json::Value::Array(vec![v]);
137            } else if value.is_null() {
138                *value = serde_json::Value::Array(vec![]);
139            }
140        }
141        _ => {}
142    }
143}
144
145/// Recursively coerce tool arguments to match a JSON Schema (modifies in place).
146pub fn coerce_with_json_schema(schema: &serde_json::Value, args: &mut serde_json::Value) {
147    // Handle composed schemas (matching pi's coerceWithJsonSchema order)
148    if let Some(all_of) = schema.get("allOf").and_then(|v| v.as_array()) {
149        for sub in all_of {
150            coerce_with_json_schema(sub, args);
151        }
152    }
153
154    if let Some(any_of) = schema.get("anyOf").and_then(|v| v.as_array()) {
155        // Try each anyOf alternative; keep the first that changes the value
156        if !any_of.is_empty() {
157            let original = args.clone();
158            for sub in any_of {
159                let mut candidate = original.clone();
160                coerce_with_json_schema(sub, &mut candidate);
161                if candidate != original {
162                    *args = candidate;
163                    break;
164                }
165            }
166        }
167    }
168
169    if let Some(one_of) = schema.get("oneOf").and_then(|v| v.as_array()) {
170        // Same strategy for oneOf
171        if !one_of.is_empty() {
172            let original = args.clone();
173            for sub in one_of {
174                let mut candidate = original.clone();
175                coerce_with_json_schema(sub, &mut candidate);
176                if candidate != original {
177                    *args = candidate;
178                    break;
179                }
180            }
181        }
182    }
183
184    if !args.is_object() {
185        return;
186    }
187    let Some(properties) = schema.get("properties").and_then(|p| p.as_object()) else {
188        return;
189    };
190    for (key, prop_schema) in properties {
191        if args.get(key).is_none() {
192            continue;
193        }
194        let arg_value = args.get_mut(key).unwrap();
195
196        // Try each schema type in order (matching pi's approach of iterating types)
197        let schema_types = collect_schema_types(prop_schema);
198        if !schema_types.is_empty() {
199            // Check if value already matches one of the types
200            let already_matches = schema_types.iter().any(|t| matches_json_type(arg_value, t));
201            if !already_matches {
202                for st in &schema_types {
203                    let before = arg_value.clone();
204                    coerce_primitive_by_type(st, arg_value);
205                    if *arg_value != before {
206                        break;
207                    }
208                }
209            }
210
211            // Recurse into objects and arrays
212            if schema_types.iter().any(|t| t == "object") && arg_value.is_object() {
213                coerce_with_json_schema(prop_schema, arg_value);
214            }
215            if schema_types.iter().any(|t| t == "array")
216                && let Some(items_schema) = prop_schema.get("items")
217                && let Some(arr) = arg_value.as_array_mut()
218            {
219                for item in arr.iter_mut() {
220                    coerce_with_json_schema(items_schema, item);
221                }
222            }
223        }
224    }
225}
226
227/// Collect all type names from a schema property, handling both plain strings and arrays.
228fn collect_schema_types(schema: &serde_json::Value) -> Vec<String> {
229    let type_val = match schema.get("type") {
230        Some(t) => t,
231        None => return vec![],
232    };
233    if let Some(s) = type_val.as_str() {
234        return vec![s.to_string()];
235    }
236    if let Some(arr) = type_val.as_array() {
237        return arr
238            .iter()
239            .filter_map(|t| t.as_str().map(|s| s.to_string()))
240            .collect();
241    }
242    vec![]
243}
244
245// ── Schema validation (matching pi's validateToolArguments) ──────
246
247/// Extracts the effective JSON Schema type from a property schema.
248/// Returns `None` when the schema has no recognizable type.
249fn resolve_schema_type(schema: &serde_json::Value) -> Option<&str> {
250    let type_val = schema.get("type")?;
251    if type_val.is_string() {
252        return type_val.as_str();
253    }
254    if type_val.is_array() {
255        // Use the first non-null type (handles ["string", "null"])
256        // This is still used by validate_tool_arguments for single-type checks
257        return type_val
258            .as_array()
259            .and_then(|arr| arr.iter().find_map(|t| t.as_str().filter(|&s| s != "null")));
260    }
261    None
262}
263fn matches_json_type(value: &serde_json::Value, schema_type: &str) -> bool {
264    match schema_type {
265        "string" => value.is_string(),
266        "number" => value.is_number(),
267        "integer" => value.is_i64() || value.is_u64(),
268        "boolean" => value.is_boolean(),
269        "null" => value.is_null(),
270        "array" => value.is_array(),
271        "object" => value.is_object(),
272        _ => true, // unknown type — don't reject
273    }
274}
275
276/// Check whether a value matches at least one of the schema's types (handles ["string", "null"]).
277fn value_matches_schema_types(schema: &serde_json::Value, value: &serde_json::Value) -> bool {
278    let type_val = match schema.get("type") {
279        Some(t) => t,
280        None => return true,
281    };
282    if type_val.is_string() {
283        return matches_json_type(value, type_val.as_str().unwrap());
284    }
285    if let Some(types) = type_val.as_array() {
286        return types
287            .iter()
288            .filter_map(|t| t.as_str())
289            .any(|t| matches_json_type(value, t));
290    }
291    true
292}
293
294/// Recursively collect validation errors for a value against a JSON Schema.
295/// Path format matches pi's formatValidationPath: "root", "edits", "edits.0.oldText".
296fn collect_validation_errors(
297    schema: &serde_json::Value,
298    value: &serde_json::Value,
299    path: &str,
300    errors: &mut Vec<ValidationError>,
301) {
302    // Root must be an object — every tool schema is "type": "object"
303    if (path.is_empty() || path == "root")
304        && let Some(schema_type) = resolve_schema_type(schema)
305        && schema_type == "object"
306        && !value.is_object()
307    {
308        errors.push(ValidationError {
309            path: path.to_string(),
310            message: "Expected object".to_string(),
311        });
312        return;
313    }
314
315    // Not an object — only check type (won't recurse)
316    if !value.is_object()
317        && let Some(schema_type) = resolve_schema_type(schema)
318        && !matches_json_type(value, schema_type)
319    {
320        let expected = if schema_type == "integer" {
321            "integer"
322        } else {
323            schema_type
324        };
325        errors.push(ValidationError {
326            path: path.to_string(),
327            message: format!("Expected {}", expected),
328        });
329        return;
330    }
331
332    if !value.is_object() {
333        return;
334    }
335
336    let obj = value.as_object().unwrap();
337    let properties = schema.get("properties").and_then(|p| p.as_object());
338    let known_keys: std::collections::HashSet<&str> = properties
339        .map(|p| p.keys().map(|k| k.as_str()).collect())
340        .unwrap_or_default();
341
342    // Check required properties
343    if let Some(required) = schema.get("required").and_then(|r| r.as_array()) {
344        for required_val in required {
345            if let Some(required_key) = required_val.as_str()
346                && !obj.contains_key(required_key)
347            {
348                let err_path = if path.is_empty() || path == "root" {
349                    required_key.to_string()
350                } else {
351                    format!("{}.{}", path, required_key)
352                };
353                errors.push(ValidationError {
354                    path: err_path,
355                    message: "Required".to_string(),
356                });
357            }
358        }
359    }
360
361    // Check additionalProperties
362    if schema.get("additionalProperties") == Some(&serde_json::Value::Bool(false)) {
363        for key in obj.keys() {
364            if !known_keys.contains(key.as_str()) {
365                let err_path = if path.is_empty() || path == "root" {
366                    key.clone()
367                } else {
368                    format!("{}.{}", path, key)
369                };
370                errors.push(ValidationError {
371                    path: err_path,
372                    message: "must NOT have additional properties".to_string(),
373                });
374            }
375        }
376    }
377
378    // Validate each property
379    if let Some(props) = properties {
380        for (key, prop_schema) in props {
381            if let Some(val) = value.get(key) {
382                let child_path = if path.is_empty() || path == "root" {
383                    key.clone()
384                } else {
385                    format!("{}.{}", path, key)
386                };
387                validate_property(prop_schema, val, &child_path, errors);
388            }
389        }
390    }
391}
392
393/// Validate a single property value against its schema, recursing into objects/arrays.
394fn validate_property(
395    schema: &serde_json::Value,
396    value: &serde_json::Value,
397    path: &str,
398    errors: &mut Vec<ValidationError>,
399) {
400    // Check type match
401    if !value_matches_schema_types(schema, value) {
402        let schema_type = resolve_schema_type(schema).unwrap_or("unknown");
403        let expected = if schema_type == "integer" {
404            "integer"
405        } else {
406            schema_type
407        };
408        errors.push(ValidationError {
409            path: path.to_string(),
410            message: format!("Expected {}", expected),
411        });
412        return; // Don't recurse into wrong-typed values
413    }
414
415    // Recurse into objects
416    if value.is_object() {
417        // Only recurse if the schema also describes an object
418        let schema_type = resolve_schema_type(schema);
419        if schema_type == Some("object") {
420            collect_validation_errors(schema, value, path, errors);
421        }
422        return;
423    }
424
425    // Recurse into array items
426    if let Some(arr) = value.as_array()
427        && resolve_schema_type(schema) == Some("array")
428        && let Some(items_schema) = schema.get("items")
429    {
430        for (i, item) in arr.iter().enumerate() {
431            let item_path = format!("{}.{}", path, i);
432            validate_property(items_schema, item, &item_path, errors);
433        }
434    }
435}
436
/// A single validation error, matching pi's TypeBox error structure.
///
/// Produced by `collect_validation_errors` / `validate_property` and rendered
/// by `validate_tool_arguments` as one `  - <path>: <message>` line.
#[derive(Debug, Clone)]
pub struct ValidationError {
    /// Path to the field, e.g. "edits.0.oldText" or "root".
    /// Segments are dot-separated; array elements appear as their index.
    pub path: String,
    /// Error message, e.g. "Required" or "must NOT have additional properties"
    pub message: String,
}
445
446/// Validate tool arguments against its JSON Schema (matching pi's validateToolArguments).
447///
448/// Returns `Ok(())` on success, or `Err` with pi-compatible format:
449/// ```text
450/// Validation failed for tool "edit":
451///   - path: Required
452///   - edits[0].oldText: Required
453///
454/// Received arguments:
455/// {
456///   "path": "/foo.txt"
457/// }
458/// ```
459pub fn validate_tool_arguments(
460    tool_name: &str,
461    schema: &serde_json::Value,
462    args: &serde_json::Value,
463) -> Result<(), String> {
464    let mut errors: Vec<ValidationError> = Vec::new();
465    collect_validation_errors(schema, args, "root", &mut errors);
466
467    if errors.is_empty() {
468        return Ok(());
469    }
470
471    let error_lines: Vec<String> = errors
472        .iter()
473        .map(|e| format!("  - {}: {}", e.path, e.message))
474        .collect();
475
476    let pretty_args =
477        serde_json::to_string_pretty(args).unwrap_or_else(|_| "<unprintable>".to_string());
478
479    Err(format!(
480        "Validation failed for tool \"{tool_name}\":\n{}\n\nReceived arguments:\n{pretty_args}",
481        error_lines.join("\n"),
482    ))
483}
484
/// An autocomplete item for slash command arguments.
///
/// Returned by `CommandHandler::argument_completions`.
#[derive(Debug, Clone)]
pub struct AutocompleteItem {
    /// The value to insert when selected.
    pub value: String,
    /// Display label.
    pub label: String,
    /// Optional description.
    pub description: Option<String>,
}
495
/// A slash command handler (built-in or extension-provided).
/// Commands use the same Extension trait as tools - built-ins and
/// user extensions register commands through a uniform interface.
///
/// Implementors must be `Send + Sync`.
pub trait CommandHandler: Send + Sync {
    /// Execute the command with the given arguments string.
    ///
    /// Returns the `CommandResult` describing what the caller should do next,
    /// or an error if the command could not be carried out.
    fn execute(&self, args: &str) -> anyhow::Result<CommandResult>;

    /// Get argument completions for autocomplete.
    /// Called when user types `/cmd ` - returns matching autocomplete items.
    ///
    /// The default implementation offers no completions.
    fn argument_completions(&self, _prefix: &str) -> Vec<AutocompleteItem> {
        vec![]
    }
}
509
/// Result of executing a slash command.
///
/// Each variant tells the caller which follow-up action to take; the handler
/// itself only reports the outcome through this value.
#[derive(Debug, Clone)]
pub enum CommandResult {
    /// Command handled, show this info message.
    Info(String),
    /// Command caused a quit request.
    Quit,
    /// Command switched the model (new model name).
    ModelChanged(String),
    /// Show keyboard shortcuts help overlay.
    ShowHelp,
    /// Reload settings, extensions, keybindings, themes from disk.
    Reloaded,
    /// Start a new session (clear conversation).
    NewSession,
    /// Switch to a different session file.
    SessionSwitched { path: std::path::PathBuf },
    /// Show session info (ID, file, messages, tokens, cost).
    SessionInfo {
        session_id: String,
        file_path: Option<std::path::PathBuf>,
        name: Option<String>,
        message_count: usize,
        user_messages: usize,
        assistant_messages: usize,
        tool_calls: usize,
        tool_results: usize,
        total_tokens: u64,
        input_tokens: u64,
        output_tokens: u64,
        cache_read_tokens: u64,
        cache_write_tokens: u64,
        cost: f64,
    },
    /// Open session selector UI.
    OpenSessionSelector,
    /// Name was set for the session.
    SessionNamed { name: String },
    /// Open settings menu.
    OpenSettings,
    /// Open model selector UI.
    OpenModelSelector,
    /// Enable/disable models for cycling.
    ScopedModels,
    /// Export session (HTML default, or specify path).
    ExportSession { path: Option<String> },
    /// Import and resume a session from a JSONL file.
    ImportSession { path: String },
    /// Share session as a secret GitHub gist.
    ShareSession,
    /// Copy last agent message to clipboard.
    CopyLastMessage,
    /// Show changelog entries.
    ShowChangelog,
    /// Create a new fork from a previous user message.
    ForkSession { message_id: Option<String> },
    /// Duplicate the current session at the current position.
    CloneSession,
    /// Navigate session tree (switch branches).
    SessionTree,
    /// Save project trust decision.
    TrustDecision { decision: String },
    /// Configure provider authentication.
    Login {
        provider: Option<String>,
        api_key: Option<String>,
    },
    /// Remove provider authentication.
    Logout { provider: Option<String> },
    /// Manually compact the session context.
    // NOTE(review): the meaning of the optional string payload is not visible
    // in this file — presumably user-supplied compaction instructions; confirm.
    CompactSession(Option<String>),
}
582
/// A registered slash command.
///
/// Pairs the command's name and description with the boxed handler that
/// executes it; `Extension::commands` returns a list of these.
pub struct SlashCommand {
    /// Command name.
    pub name: String,
    /// Human-readable description of the command.
    pub description: String,
    /// Handler invoked to run the command and to provide argument completions.
    pub handler: Box<dyn CommandHandler>,
}
589
590/// Simple cancellation token for tool execution.
591/// Shared between the agent loop and tool execution to signal cancellation.
592#[derive(Debug, Clone)]
593pub struct Cancel {
594    flag: Arc<AtomicBool>,
595}
596
597impl Cancel {
598    pub fn new() -> Self {
599        Self {
600            flag: Arc::new(AtomicBool::new(false)),
601        }
602    }
603
604    /// Check whether cancellation has been requested.
605    pub fn is_cancelled(&self) -> bool {
606        self.flag.load(Ordering::Relaxed)
607    }
608
609    /// Request cancellation.
610    pub fn cancel(&self) {
611        self.flag.store(true, Ordering::Relaxed);
612    }
613
614    /// Check if cancelled, returning an error if so.
615    pub fn check(&self) -> anyhow::Result<()> {
616        if self.is_cancelled() {
617            Err(anyhow::anyhow!("Operation cancelled"))
618        } else {
619            Ok(())
620        }
621    }
622}
623
624impl Default for Cancel {
625    fn default() -> Self {
626        Self::new()
627    }
628}
629
/// Context passed to ToolRenderer methods (matching pi's ToolRenderContext).
/// Carries all metadata about the tool execution that renderers may need.
///
/// Because `state` is an `Rc<RefCell<..>>`, cloning the context shares that
/// state rather than copying it, and the context cannot be sent across threads.
#[derive(Debug, Clone)]
pub struct ToolRenderContext {
    // NOTE(review): the four flags below carry no documentation in the source;
    // their names suggest UI expansion state, argument-streaming completion,
    // partial-result and error status respectively — confirm against callers.
    pub expanded: bool,
    pub args_complete: bool,
    pub is_partial: bool,
    pub is_error: bool,
    /// Unique id for this tool execution (pi's toolCallId).
    pub tool_call_id: String,
    /// Whether the tool execution has started (pi's executionStarted).
    pub execution_started: bool,
    /// Working directory for path resolution.
    pub cwd: String,
    /// Duration in seconds (bash).
    pub duration_secs: Option<f64>,
    /// Exit code (bash).
    pub exit_code: Option<i32>,
    /// Whether execution was cancelled (bash).
    pub cancelled: bool,
    /// Whether output was truncated (bash/read).
    pub was_truncated: bool,
    /// Path to full output file (bash).
    pub full_output_path: Option<String>,
    /// File path for syntax highlighting (read).
    pub file_path: Option<String>,
    /// Keybinding hint for the expand action, e.g. "C-O".
    pub expand_key: String,
    /// Structured rendering details from the tool execution (pi-compatible).
    /// Set by tool renderers for preview/actual diff data. Not sent to the LLM.
    pub details: Option<serde_json::Value>,
    /// Shared mutable state per tool execution (pi's context.state).
    /// Initialized as an empty JSON object `{}`. Renderers can mutate it
    /// across renderCall/renderResult invocations for the same tool call.
    pub state: std::rc::Rc<std::cell::RefCell<serde_json::Value>>,
    /// Callback for renderers to request re-render (e.g. after async preview computation).
    /// Pi-compatible: `context.invalidate()` in renderCall/renderResult.
    /// Cloned from the original at context construction time.
    /// Uses a channel sender internally to bridge from async to UI thread.
    /// `None` means no re-render channel is available.
    pub invalidate: Option<tokio::sync::mpsc::UnboundedSender<()>>,
}
671
/// Tool-specific rendering interface (matching pi's renderCall/renderResult pattern).
/// Each built-in tool implements this to provide its own visual representation.
///
/// Only `render_call` and `render_result` are required; the framing hooks have
/// defaults. Implementors must be `Send + Sync`.
pub trait ToolRenderer: Send + Sync {
    /// Render the tool call header/title.
    /// Returns ANSI-styled lines for the call portion (inside the colored box shell).
    ///
    /// `args` are the tool call arguments, `width` the available width, `theme`
    /// the active theme and `ctx` the execution metadata for this call.
    fn render_call(
        &self,
        args: &serde_json::Value,
        width: usize,
        theme: &dyn Theme,
        ctx: &ToolRenderContext,
    ) -> Vec<String>;

    /// Render the tool result body.
    /// Returns lines to display as the result body, or empty vec for no result.
    /// When empty, only the call portion is shown (e.g. write success).
    ///
    /// `content` is the result text; the remaining parameters are as for
    /// `render_call`.
    fn render_result(
        &self,
        content: &str,
        width: usize,
        theme: &dyn Theme,
        ctx: &ToolRenderContext,
    ) -> Vec<String>;

    /// Whether this tool uses `renderShell: "self"` (controls its own framing).
    /// When true, ToolExecComponent does NOT wrap the tool in a colored background box.
    /// Defaults to `false`.
    fn render_self(&self) -> bool {
        false
    }

    /// Optional hint for the background color key when `render_self()` is false.
    /// Return a theme key name (e.g. "toolPendingBg", "toolSuccessBg", "toolErrorBg")
    /// to override the default background selection. Return None to let the
    /// ToolExecComponent decide based on is_complete/is_error state.
    /// Used by edit tool to show success/error bg during preview.
    /// Defaults to `None`.
    fn render_bg_key(&self) -> Option<&'static str> {
        None
    }
}
711
712#[async_trait::async_trait]
713impl yoagent::types::AgentTool for ToolDefinition {
714    fn name(&self) -> &str {
715        self.tool.name()
716    }
717
718    fn label(&self) -> &str {
719        self.tool.label()
720    }
721
722    fn description(&self) -> &str {
723        self.tool.description()
724    }
725
726    fn parameters_schema(&self) -> serde_json::Value {
727        self.tool.parameters_schema()
728    }
729
730    async fn execute(
731        &self,
732        params: serde_json::Value,
733        ctx: yoagent::types::ToolContext,
734    ) -> std::result::Result<yoagent::types::ToolResult, yoagent::types::ToolError> {
735        let mut params = match self.prepare_arguments {
736            Some(prepare) => prepare(params).map_err(yoagent::types::ToolError::InvalidArgs)?,
737            None => params,
738        };
739        // Step 1: type coercion (matching pi's Value.Convert + coerceWithJsonSchema)
740        let schema = self.tool.parameters_schema();
741        coerce_with_json_schema(&schema, &mut params);
742
743        // Step 2: validate against schema (matching pi's validateToolArguments)
744        let tool_name = self.tool.name();
745        validate_tool_arguments(tool_name, &schema, &params)
746            .map_err(yoagent::types::ToolError::InvalidArgs)?;
747
748        // Step 3: before_tool_call hook (matching pi's beforeToolCall)
749        if let Some(ref hook) = self.before_tool_call
750            && let Some(result) = hook(&params)
751            && result.block
752        {
753            let reason = if result.reason.is_empty() {
754                format!("Tool {} execution blocked", tool_name)
755            } else {
756                result.reason
757            };
758            return Err(yoagent::types::ToolError::Failed(reason));
759        }
760
761        // Step 4: execute the inner tool
762        let (mut tool_result, mut is_error) = match self.tool.execute(params, ctx).await {
763            Ok(r) => (r, false),
764            Err(e) => {
765                let err_text = e.to_string();
766                (
767                    yoagent::types::ToolResult {
768                        content: vec![yoagent::types::Content::Text { text: err_text }],
769                        details: serde_json::Value::Null,
770                    },
771                    true,
772                )
773            }
774        };
775
776        // Step 5: after_tool_call hook (matching pi's afterToolCall)
777        if let Some(ref hook) = self.after_tool_call
778            && let Some(override_result) = hook(&tool_result, is_error)
779        {
780            if let Some(content) = override_result.content {
781                tool_result.content = content;
782            }
783            if let Some(details) = override_result.details {
784                tool_result.details = details;
785            }
786            if let Some(err) = override_result.is_error {
787                is_error = err;
788            }
789        }
790
791        if is_error {
792            let error_text: String = tool_result
793                .content
794                .iter()
795                .filter_map(|c| {
796                    if let yoagent::types::Content::Text { text } = c {
797                        Some(text.as_str())
798                    } else {
799                        None
800                    }
801                })
802                .collect::<Vec<_>>()
803                .join("\n");
804            Err(yoagent::types::ToolError::Failed(error_text))
805        } else {
806            Ok(tool_result)
807        }
808    }
809}
810
/// Extension interface: tools, slash commands, skills and compaction hooks
/// are all contributed through this trait.
///
/// Only `name` is required; every other method has a default that contributes
/// nothing (empty list, empty skill set, `None`, or a no-op). Implementors must
/// be `Send + Sync`.
pub trait Extension: Send + Sync {
    /// Name of this extension.
    fn name(&self) -> Cow<'static, str>;

    /// Tools this extension provides (LLM-callable), each with its own prompt metadata.
    fn tools(&self) -> Vec<ToolDefinition> {
        vec![]
    }

    /// Slash commands this extension provides (e.g. `/quit`, `/model`).
    /// Built-in commands and extension commands use the same interface.
    fn commands(&self) -> Vec<SlashCommand> {
        vec![]
    }

    /// Skills this extension provides (AgentSkills-compatible).
    /// Merged into the session's skill set for /skill:name expansion and system prompt.
    fn skills(&self) -> yoagent::skills::SkillSet {
        yoagent::skills::SkillSet::empty()
    }

    /// Called before compaction runs (matching pi's `session_before_compact`).
    /// Return `Some(BeforeCompactResult { cancel: true, .. })` to cancel compaction.
    /// Return `Some(BeforeCompactResult { cancel: false, summary: Some(...), .. })`
    /// to provide a custom summary instead of calling the provider.
    /// Return `None` to let the default compaction proceed.
    ///
    /// `cancel` is a cancellation token — check `cancel.is_cancelled()` in
    /// long-running hooks and return immediately if true (matching pi's
    /// `AbortSignal` passed to `session_before_compact`).
    ///
    /// The default implementation returns `None`.
    fn before_compact(
        &self,
        _first_kept_entry_id: &str,
        _tokens_before: u64,
        _reason: &str,
        _cancel: &Cancel,
    ) -> Option<BeforeCompactResult> {
        None
    }

    /// Called after compaction completes (matching pi's `session_compact`).
    ///
    /// `cancel` is a cancellation token — check `cancel.is_cancelled()` in
    /// long-running hooks and return early if true (matching pi's
    /// `AbortSignal` passed to `session_compact`).
    ///
    /// The default implementation does nothing.
    // The argument list is wide by design, hence the lint exemption.
    #[allow(clippy::too_many_arguments)]
    fn after_compact(
        &self,
        _summary: &str,
        _first_kept_entry_id: &str,
        _tokens_before: u64,
        _estimated_tokens_after: u64,
        _from_hook: bool,
        _reason: &str,
        _cancel: &Cancel,
    ) {
    }
}
868
869// ── Tests ──────────────────────────────────────────────────────────
870
#[cfg(test)]
mod tests {
    use super::*;

    // ── Shared helpers ──────────────────────────────────────────────

    /// Applies `coerce_primitive_by_type` to an owned value and hands back
    /// whatever the value looks like afterwards.
    fn coerced(type_name: &str, mut value: serde_json::Value) -> serde_json::Value {
        coerce_primitive_by_type(type_name, &mut value);
        value
    }

    /// Applies `coerce_with_json_schema` to owned arguments and hands back
    /// whatever the arguments look like afterwards.
    fn schema_coerced(
        schema: &serde_json::Value,
        mut args: serde_json::Value,
    ) -> serde_json::Value {
        coerce_with_json_schema(schema, &mut args);
        args
    }

    /// Object schema with a single required string property named `path`.
    fn required_path_schema() -> serde_json::Value {
        serde_json::json!({
            "type": "object",
            "properties": {
                "path": {"type": "string"}
            },
            "required": ["path"]
        })
    }

    // ── coerce_primitive_by_type ────────────────────────────────────

    #[test]
    fn test_coerce_string_from_number() {
        let actual = coerced("string", serde_json::json!(42));
        assert_eq!(actual, serde_json::json!("42"));
    }

    #[test]
    fn test_coerce_string_from_boolean() {
        let actual = coerced("string", serde_json::json!(true));
        assert_eq!(actual, serde_json::json!("true"));
    }

    #[test]
    fn test_coerce_string_from_null() {
        let actual = coerced("string", serde_json::json!(null));
        assert_eq!(actual, serde_json::json!(""));
    }

    #[test]
    fn test_coerce_string_unchanged() {
        let actual = coerced("string", serde_json::json!("hello"));
        assert_eq!(actual, serde_json::json!("hello"));
    }

    #[test]
    fn test_coerce_number_from_string() {
        let actual = coerced("number", serde_json::json!("42.5"));
        assert_eq!(actual, serde_json::json!(42.5));
    }

    #[test]
    fn test_coerce_number_from_boolean() {
        let actual = coerced("number", serde_json::json!(true));
        assert_eq!(actual, serde_json::json!(1.0));
    }

    #[test]
    fn test_coerce_number_from_null() {
        let actual = coerced("number", serde_json::json!(null));
        assert_eq!(actual, serde_json::json!(0.0));
    }

    #[test]
    fn test_coerce_integer_from_string() {
        let actual = coerced("integer", serde_json::json!("7"));
        assert_eq!(actual, serde_json::json!(7i64));
    }

    #[test]
    fn test_coerce_integer_from_float() {
        let actual = coerced("integer", serde_json::json!(3.9));
        assert_eq!(actual, serde_json::json!(3i64));
    }

    #[test]
    fn test_coerce_integer_from_boolean() {
        let actual = coerced("integer", serde_json::json!(false));
        assert_eq!(actual, serde_json::json!(0i64));
    }

    #[test]
    fn test_coerce_boolean_from_string_true() {
        let actual = coerced("boolean", serde_json::json!("true"));
        assert_eq!(actual, serde_json::json!(true));
    }

    #[test]
    fn test_coerce_boolean_from_string_yes() {
        let actual = coerced("boolean", serde_json::json!("yes"));
        assert_eq!(actual, serde_json::json!(true));
    }

    #[test]
    fn test_coerce_boolean_from_number() {
        let actual = coerced("boolean", serde_json::json!(1));
        assert_eq!(actual, serde_json::json!(true));
    }

    #[test]
    fn test_coerce_boolean_from_null() {
        let actual = coerced("boolean", serde_json::json!(null));
        assert_eq!(actual, serde_json::json!(false));
    }

    #[test]
    fn test_coerce_array_from_scalar() {
        let actual = coerced("array", serde_json::json!("single"));
        assert_eq!(actual, serde_json::json!(["single"]));
    }

    #[test]
    fn test_coerce_array_from_null() {
        let actual = coerced("array", serde_json::json!(null));
        assert_eq!(actual, serde_json::json!([]));
    }

    #[test]
    fn test_coerce_array_unchanged() {
        let actual = coerced("array", serde_json::json!([1, 2, 3]));
        assert_eq!(actual, serde_json::json!([1, 2, 3]));
    }

    #[test]
    fn test_coerce_unknown_type_does_nothing() {
        // "widget" is not a JSON Schema type name; the value must pass through untouched.
        let actual = coerced("widget", serde_json::json!(42));
        assert_eq!(actual, serde_json::json!(42));
    }

    // ── coerce_with_json_schema ─────────────────────────────────────

    #[test]
    fn test_coerce_schema_string_from_number() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            }
        });
        let actual = schema_coerced(&schema, serde_json::json!({"name": 42}));
        assert_eq!(actual, serde_json::json!({"name": "42"}));
    }

    #[test]
    fn test_coerce_schema_nested_object() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "metadata": {
                    "type": "object",
                    "properties": {
                        "count": {"type": "integer"}
                    }
                }
            }
        });
        let actual = schema_coerced(&schema, serde_json::json!({"metadata": {"count": "5"}}));
        assert_eq!(actual, serde_json::json!({"metadata": {"count": 5i64}}));
    }

    #[test]
    fn test_coerce_schema_array_items() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "items": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "id": {"type": "integer"}
                        }
                    }
                }
            }
        });
        let actual = schema_coerced(
            &schema,
            serde_json::json!({"items": [{"id": "3"}, {"id": "7"}]}),
        );
        assert_eq!(
            actual,
            serde_json::json!({"items": [{"id": 3i64}, {"id": 7i64}]})
        );
    }

    #[test]
    fn test_coerce_schema_non_object_skipped() {
        // A schema whose root is not an object leaves the arguments as they were.
        let schema = serde_json::json!({"type": "string"});
        let actual = schema_coerced(&schema, serde_json::json!("hello"));
        assert_eq!(actual, serde_json::json!("hello"));
    }

    // ── validate_tool_arguments ─────────────────────────────────────

    #[test]
    fn test_validate_valid_args() {
        let schema = required_path_schema();
        let args = serde_json::json!({"path": "/tmp/foo.txt"});
        let outcome = validate_tool_arguments("test", &schema, &args);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validate_missing_required() {
        let schema = required_path_schema();
        let args = serde_json::json!({});
        let message = validate_tool_arguments("test", &schema, &args).unwrap_err();
        // The message must name both the violation and the tool it was raised for.
        assert!(message.contains("Required"));
        assert!(message.contains("test"));
    }

    #[test]
    fn test_validate_wrong_type() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "count": {"type": "integer"}
            }
        });
        let args = serde_json::json!({"count": "not-a-number"});
        let message = validate_tool_arguments("test", &schema, &args).unwrap_err();
        assert!(message.contains("Expected integer"));
    }

    #[test]
    fn test_validate_additional_properties() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            },
            "additionalProperties": false
        });
        let args = serde_json::json!({"name": "alice", "extra": "bad"});
        let message = validate_tool_arguments("test", &schema, &args).unwrap_err();
        assert!(message.contains("must NOT have additional properties"));
    }

    #[test]
    fn test_validate_not_an_object() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {}
        });
        let args = serde_json::json!("a string, not an object");
        let message = validate_tool_arguments("test", &schema, &args).unwrap_err();
        assert!(message.contains("Expected object"));
    }

    #[test]
    fn test_validate_array_item_types() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "tags": {
                    "type": "array",
                    "items": {"type": "string"}
                }
            }
        });
        let args = serde_json::json!({"tags": [1, 2, 3]});
        let message = validate_tool_arguments("test", &schema, &args).unwrap_err();
        assert!(message.contains("Expected string"));
    }

    // ── Cancel ──────────────────────────────────────────────────────

    #[test]
    fn test_cancel_new_not_cancelled() {
        let token = Cancel::new();
        assert!(!token.is_cancelled());
        assert!(token.check().is_ok());
    }

    #[test]
    fn test_cancel_after_cancel() {
        let token = Cancel::new();
        token.cancel();
        assert!(token.is_cancelled());
        assert!(token.check().is_err());
    }

    #[test]
    fn test_cancel_default_not_cancelled() {
        let token = Cancel::default();
        assert!(!token.is_cancelled());
    }

    #[test]
    fn test_cancel_is_send_sync() {
        // Compile-time probe: instantiating the bound is the whole assertion.
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<Cancel>();
    }
}