Skip to main content

rab/agent/
extension.rs

1/// Extension trait - all capability (built-in or user-provided) comes through this.
2use crate::tui::Theme;
3use std::borrow::Cow;
4use std::sync::{
5    Arc,
6    atomic::{AtomicBool, Ordering},
7};
8
9// ── Tool call hooks (matching pi's beforeToolCall / afterToolCall) ──
10
/// Verdict returned from a tool's `before_tool_call` hook (matching pi's `BeforeToolCallResult`).
///
/// Setting `block: true` prevents the tool from executing; `reason` then becomes the
/// error text. `Default` produces a non-blocking verdict with an empty reason.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct BeforeToolCallResult {
    /// When true, the tool execution is blocked.
    pub block: bool,
    /// Error text reported when `block` is true. When empty, a generic
    /// "execution blocked" message is used instead.
    pub reason: String,
}
19
20/// Partial override returned from `after_tool_call` (matching pi's `AfterToolCallResult`).
21/// Merge semantics are field-by-field: provided fields replace the original; omitted fields keep their values.
22pub struct AfterToolCallResult {
23    /// If provided, replaces the tool result content array in full.
24    pub content: Option<Vec<yoagent::types::Content>>,
25    /// If provided, replaces the tool result details value in full.
26    pub details: Option<serde_json::Value>,
27    /// If provided, replaces the tool result error flag.
28    pub is_error: Option<bool>,
29}
30
31/// Result returned from `before_compact` (matching pi's `SessionBeforeCompactResult`).
32/// Returning `{ cancel: true }` prevents compaction.
33pub struct BeforeCompactResult {
34    /// If true, compaction is cancelled entirely.
35    pub cancel: bool,
36    /// If provided, uses this summary instead of calling the provider.
37    pub summary: Option<String>,
38    /// Optional details stored with the compaction entry.
39    pub details: Option<serde_json::Value>,
40}
41
42/// A tool bundled with its prompt metadata.
43///
44/// Mirrors pi's `ToolDefinition` which carries `promptSnippet`,
45/// `promptGuidelines` and `prepareArguments` directly on the tool definition.
46pub struct ToolDefinition {
47    pub tool: Box<dyn yoagent::types::AgentTool>,
48    /// One-line snippet for the "Available tools" section of the system prompt.
49    pub snippet: &'static str,
50    /// Guideline bullets for the "Guidelines" section of the system prompt.
51    pub guidelines: &'static [&'static str],
52    /// Optional pre-processing of raw LLM arguments before execute().
53    /// Receives raw arguments, returns normalized arguments or an error message.
54    pub prepare_arguments: Option<fn(serde_json::Value) -> Result<serde_json::Value, String>>,
55    /// Called before tool execution, after argument validation (matching pi's `beforeToolCall`).
56    /// Return `Some(BeforeToolCallResult { block: true, reason: "..." })` to block execution.
57    pub before_tool_call: Option<fn(&serde_json::Value) -> Option<BeforeToolCallResult>>,
58    /// Called after tool execution, before the result is returned (matching pi's `afterToolCall`).
59    pub after_tool_call:
60        Option<fn(&yoagent::types::ToolResult, bool) -> Option<AfterToolCallResult>>,
61    /// Tool-specific renderer for the TUI, bundled with the tool definition
62    /// (pi's renderCall/renderResult live on ToolDefinition).
63    pub renderer: Option<Arc<dyn ToolRenderer>>,
64}
65
66// ── Generic argument type coercion & validation ─────────────────
67
68/// Coerce a single JSON value to match a JSON Schema type (modifies in place).
69/// This handles common LLM mistakes: sending numbers as strings, booleans as strings, etc.
70pub fn coerce_primitive_by_type(schema_type: &str, value: &mut serde_json::Value) {
71    match schema_type {
72        "string" => {
73            if value.is_number() || value.is_boolean() {
74                *value = serde_json::Value::String(match value {
75                    serde_json::Value::Number(n) => n.to_string(),
76                    serde_json::Value::Bool(b) => b.to_string(),
77                    _ => unreachable!(),
78                });
79            } else if value.is_null() {
80                *value = serde_json::Value::String(String::new());
81            } else if value.is_array() || value.is_object() {
82                // TypeBox's Value.Convert stringifies arrays/objects when schema expects string
83                *value =
84                    serde_json::Value::String(serde_json::to_string(value).unwrap_or_default());
85            }
86        }
87        "number" => {
88            if let Some(s) = value.as_str() {
89                if let Ok(n) = s.parse::<f64>() {
90                    *value = serde_json::json!(n);
91                }
92            } else if value.is_boolean() {
93                *value = serde_json::json!(if value.as_bool().unwrap() { 1.0 } else { 0.0 });
94            } else if value.is_null() {
95                *value = serde_json::json!(0.0);
96            }
97        }
98        "integer" => {
99            if let Some(s) = value.as_str() {
100                if let Ok(n) = s.parse::<f64>() {
101                    *value = serde_json::json!(n as i64);
102                }
103            } else if value.is_boolean() {
104                *value = serde_json::json!(if value.as_bool().unwrap() { 1i64 } else { 0i64 });
105            } else if value.is_null() {
106                *value = serde_json::json!(0i64);
107            } else if let Some(n) = value.as_f64() {
108                *value = serde_json::json!(n as i64);
109            }
110        }
111        "boolean" => {
112            if let Some(s) = value.as_str() {
113                match s.trim().to_lowercase().as_str() {
114                    "true" | "1" | "yes" | "on" => *value = serde_json::Value::Bool(true),
115                    "false" | "0" | "no" | "off" => *value = serde_json::Value::Bool(false),
116                    _ => {} // Leave as-is if unrecognized
117                }
118            } else if value.is_number() {
119                *value = serde_json::Value::Bool(value.as_f64().unwrap_or(0.0) != 0.0);
120            } else if value.is_null() {
121                *value = serde_json::Value::Bool(false);
122            }
123        }
124        "null" => {
125            // Pi-compatible: treat empty string, 0, or false as null
126            if value.as_str().is_some_and(|s| s.is_empty())
127                || value.as_f64() == Some(0.0)
128                || value.as_bool() == Some(false)
129            {
130                *value = serde_json::Value::Null;
131            }
132        }
133        "array" => {
134            if !value.is_array() && !value.is_null() {
135                let v = std::mem::take(value);
136                *value = serde_json::Value::Array(vec![v]);
137            } else if value.is_null() {
138                *value = serde_json::Value::Array(vec![]);
139            }
140        }
141        _ => {}
142    }
143}
144
145/// Recursively coerce tool arguments to match a JSON Schema (modifies in place).
146pub fn coerce_with_json_schema(schema: &serde_json::Value, args: &mut serde_json::Value) {
147    // Handle composed schemas (matching pi's coerceWithJsonSchema order)
148    if let Some(all_of) = schema.get("allOf").and_then(|v| v.as_array()) {
149        for sub in all_of {
150            coerce_with_json_schema(sub, args);
151        }
152    }
153
154    if let Some(any_of) = schema.get("anyOf").and_then(|v| v.as_array()) {
155        // Try each anyOf alternative; keep the first that changes the value
156        if !any_of.is_empty() {
157            let original = args.clone();
158            for sub in any_of {
159                let mut candidate = original.clone();
160                coerce_with_json_schema(sub, &mut candidate);
161                if candidate != original {
162                    *args = candidate;
163                    break;
164                }
165            }
166        }
167    }
168
169    if let Some(one_of) = schema.get("oneOf").and_then(|v| v.as_array()) {
170        // Same strategy for oneOf
171        if !one_of.is_empty() {
172            let original = args.clone();
173            for sub in one_of {
174                let mut candidate = original.clone();
175                coerce_with_json_schema(sub, &mut candidate);
176                if candidate != original {
177                    *args = candidate;
178                    break;
179                }
180            }
181        }
182    }
183
184    if !args.is_object() {
185        return;
186    }
187    let Some(properties) = schema.get("properties").and_then(|p| p.as_object()) else {
188        return;
189    };
190    for (key, prop_schema) in properties {
191        if args.get(key).is_none() {
192            continue;
193        }
194        let arg_value = args.get_mut(key).unwrap();
195
196        // Try each schema type in order (matching pi's approach of iterating types)
197        let schema_types = collect_schema_types(prop_schema);
198        if !schema_types.is_empty() {
199            // Check if value already matches one of the types
200            let already_matches = schema_types.iter().any(|t| matches_json_type(arg_value, t));
201            if !already_matches {
202                for st in &schema_types {
203                    let before = arg_value.clone();
204                    coerce_primitive_by_type(st, arg_value);
205                    if *arg_value != before {
206                        break;
207                    }
208                }
209            }
210
211            // Recurse into objects and arrays
212            if schema_types.iter().any(|t| t == "object") && arg_value.is_object() {
213                coerce_with_json_schema(prop_schema, arg_value);
214            }
215            if schema_types.iter().any(|t| t == "array")
216                && let Some(items_schema) = prop_schema.get("items")
217                && let Some(arr) = arg_value.as_array_mut()
218            {
219                for item in arr.iter_mut() {
220                    coerce_with_json_schema(items_schema, item);
221                }
222            }
223        }
224    }
225}
226
227/// Collect all type names from a schema property, handling both plain strings and arrays.
228fn collect_schema_types(schema: &serde_json::Value) -> Vec<String> {
229    let type_val = match schema.get("type") {
230        Some(t) => t,
231        None => return vec![],
232    };
233    if let Some(s) = type_val.as_str() {
234        return vec![s.to_string()];
235    }
236    if let Some(arr) = type_val.as_array() {
237        return arr
238            .iter()
239            .filter_map(|t| t.as_str().map(|s| s.to_string()))
240            .collect();
241    }
242    vec![]
243}
244
245// ── Schema validation (matching pi's validateToolArguments) ──────
246
247/// Extracts the effective JSON Schema type from a property schema.
248/// Returns `None` when the schema has no recognizable type.
249fn resolve_schema_type(schema: &serde_json::Value) -> Option<&str> {
250    let type_val = schema.get("type")?;
251    if type_val.is_string() {
252        return type_val.as_str();
253    }
254    if type_val.is_array() {
255        // Use the first non-null type (handles ["string", "null"])
256        // This is still used by validate_tool_arguments for single-type checks
257        return type_val
258            .as_array()
259            .and_then(|arr| arr.iter().find_map(|t| t.as_str().filter(|&s| s != "null")));
260    }
261    None
262}
263fn matches_json_type(value: &serde_json::Value, schema_type: &str) -> bool {
264    match schema_type {
265        "string" => value.is_string(),
266        "number" => value.is_number(),
267        "integer" => value.is_i64() || value.is_u64(),
268        "boolean" => value.is_boolean(),
269        "null" => value.is_null(),
270        "array" => value.is_array(),
271        "object" => value.is_object(),
272        _ => true, // unknown type — don't reject
273    }
274}
275
276/// Check whether a value matches at least one of the schema's types (handles ["string", "null"]).
277fn value_matches_schema_types(schema: &serde_json::Value, value: &serde_json::Value) -> bool {
278    let type_val = match schema.get("type") {
279        Some(t) => t,
280        None => return true,
281    };
282    if type_val.is_string() {
283        return matches_json_type(value, type_val.as_str().unwrap());
284    }
285    if let Some(types) = type_val.as_array() {
286        return types
287            .iter()
288            .filter_map(|t| t.as_str())
289            .any(|t| matches_json_type(value, t));
290    }
291    true
292}
293
294/// Recursively collect validation errors for a value against a JSON Schema.
295/// Path format matches pi's formatValidationPath: "root", "edits", "edits.0.oldText".
296fn collect_validation_errors(
297    schema: &serde_json::Value,
298    value: &serde_json::Value,
299    path: &str,
300    errors: &mut Vec<ValidationError>,
301) {
302    // Root must be an object — every tool schema is "type": "object"
303    if (path.is_empty() || path == "root")
304        && let Some(schema_type) = resolve_schema_type(schema)
305        && schema_type == "object"
306        && !value.is_object()
307    {
308        errors.push(ValidationError {
309            path: path.to_string(),
310            message: "Expected object".to_string(),
311        });
312        return;
313    }
314
315    // Not an object — only check type (won't recurse)
316    if !value.is_object()
317        && let Some(schema_type) = resolve_schema_type(schema)
318        && !matches_json_type(value, schema_type)
319    {
320        let expected = if schema_type == "integer" {
321            "integer"
322        } else {
323            schema_type
324        };
325        errors.push(ValidationError {
326            path: path.to_string(),
327            message: format!("Expected {}", expected),
328        });
329        return;
330    }
331
332    if !value.is_object() {
333        return;
334    }
335
336    let obj = value.as_object().unwrap();
337    let properties = schema.get("properties").and_then(|p| p.as_object());
338    let known_keys: std::collections::HashSet<&str> = properties
339        .map(|p| p.keys().map(|k| k.as_str()).collect())
340        .unwrap_or_default();
341
342    // Check required properties
343    if let Some(required) = schema.get("required").and_then(|r| r.as_array()) {
344        for required_val in required {
345            if let Some(required_key) = required_val.as_str()
346                && !obj.contains_key(required_key)
347            {
348                let err_path = if path.is_empty() || path == "root" {
349                    required_key.to_string()
350                } else {
351                    format!("{}.{}", path, required_key)
352                };
353                errors.push(ValidationError {
354                    path: err_path,
355                    message: "Required".to_string(),
356                });
357            }
358        }
359    }
360
361    // Check additionalProperties
362    if schema.get("additionalProperties") == Some(&serde_json::Value::Bool(false)) {
363        for key in obj.keys() {
364            if !known_keys.contains(key.as_str()) {
365                let err_path = if path.is_empty() || path == "root" {
366                    key.clone()
367                } else {
368                    format!("{}.{}", path, key)
369                };
370                errors.push(ValidationError {
371                    path: err_path,
372                    message: "must NOT have additional properties".to_string(),
373                });
374            }
375        }
376    }
377
378    // Validate each property
379    if let Some(props) = properties {
380        for (key, prop_schema) in props {
381            if let Some(val) = value.get(key) {
382                let child_path = if path.is_empty() || path == "root" {
383                    key.clone()
384                } else {
385                    format!("{}.{}", path, key)
386                };
387                validate_property(prop_schema, val, &child_path, errors);
388            }
389        }
390    }
391}
392
393/// Validate a single property value against its schema, recursing into objects/arrays.
394fn validate_property(
395    schema: &serde_json::Value,
396    value: &serde_json::Value,
397    path: &str,
398    errors: &mut Vec<ValidationError>,
399) {
400    // Check type match
401    if !value_matches_schema_types(schema, value) {
402        let schema_type = resolve_schema_type(schema).unwrap_or("unknown");
403        let expected = if schema_type == "integer" {
404            "integer"
405        } else {
406            schema_type
407        };
408        errors.push(ValidationError {
409            path: path.to_string(),
410            message: format!("Expected {}", expected),
411        });
412        return; // Don't recurse into wrong-typed values
413    }
414
415    // Recurse into objects
416    if value.is_object() {
417        // Only recurse if the schema also describes an object
418        let schema_type = resolve_schema_type(schema);
419        if schema_type == Some("object") {
420            collect_validation_errors(schema, value, path, errors);
421        }
422        return;
423    }
424
425    // Recurse into array items
426    if let Some(arr) = value.as_array()
427        && resolve_schema_type(schema) == Some("array")
428        && let Some(items_schema) = schema.get("items")
429    {
430        for (i, item) in arr.iter().enumerate() {
431            let item_path = format!("{}.{}", path, i);
432            validate_property(items_schema, item, &item_path, errors);
433        }
434    }
435}
436
/// A single validation error, matching pi's TypeBox error structure.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationError {
    /// Dot-separated path to the offending field, e.g. `"edits.0.oldText"`, or `"root"`
    /// for the top-level value itself.
    pub path: String,
    /// Human-readable message, e.g. `"Required"` or `"must NOT have additional properties"`.
    pub message: String,
}
445
446/// Validate tool arguments against its JSON Schema (matching pi's validateToolArguments).
447///
448/// Returns `Ok(())` on success, or `Err` with pi-compatible format:
449/// ```text
450/// Validation failed for tool "edit":
451///   - path: Required
452///   - edits[0].oldText: Required
453///
454/// Received arguments:
455/// {
456///   "path": "/foo.txt"
457/// }
458/// ```
459pub fn validate_tool_arguments(
460    tool_name: &str,
461    schema: &serde_json::Value,
462    args: &serde_json::Value,
463) -> Result<(), String> {
464    let mut errors: Vec<ValidationError> = Vec::new();
465    collect_validation_errors(schema, args, "root", &mut errors);
466
467    if errors.is_empty() {
468        return Ok(());
469    }
470
471    let error_lines: Vec<String> = errors
472        .iter()
473        .map(|e| format!("  - {}: {}", e.path, e.message))
474        .collect();
475
476    let pretty_args =
477        serde_json::to_string_pretty(args).unwrap_or_else(|_| "<unprintable>".to_string());
478
479    Err(format!(
480        "Validation failed for tool \"{tool_name}\":\n{}\n\nReceived arguments:\n{pretty_args}",
481        error_lines.join("\n"),
482    ))
483}
484
/// One suggestion offered while completing a slash command's arguments.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AutocompleteItem {
    /// Text inserted when the suggestion is accepted.
    pub value: String,
    /// Text shown for the suggestion in the completion list.
    pub label: String,
    /// Optional extra description shown alongside the label.
    pub description: Option<String>,
}
495
496/// A slash command handler (built-in or extension-provided).
497/// Commands use the same Extension trait as tools - built-ins and
498/// user extensions register commands through a uniform interface.
499pub trait CommandHandler: Send + Sync {
500    /// Execute the command with the given arguments string.
501    fn execute(&self, args: &str) -> anyhow::Result<CommandResult>;
502
503    /// Get argument completions for autocomplete.
504    /// Called when user types `/cmd ` - returns matching autocomplete items.
505    fn argument_completions(&self, _prefix: &str) -> Vec<AutocompleteItem> {
506        vec![]
507    }
508}
509
/// Outcome of running a slash command, telling the UI what to do next.
#[derive(Debug, Clone)]
pub enum CommandResult {
    /// The command was handled; display this informational message.
    Info(String),
    /// The command asked the application to quit.
    Quit,
    /// The active model was switched; carries the new model name.
    ModelChanged(String),
    /// Display the keyboard-shortcuts help overlay.
    ShowHelp,
    /// Settings, extensions, keybindings and themes were reloaded from disk.
    Reloaded,
    /// Begin a fresh session with an empty conversation.
    NewSession,
    /// Switch to the session stored at `path`.
    SessionSwitched { path: std::path::PathBuf },
    /// Display session information (ID, file, message counts, tokens, cost).
    SessionInfo {
        session_id: String,
        file_path: Option<std::path::PathBuf>,
        name: Option<String>,
        message_count: usize,
        user_messages: usize,
        assistant_messages: usize,
        tool_calls: usize,
        tool_results: usize,
        total_tokens: u64,
        input_tokens: u64,
        output_tokens: u64,
        cache_read_tokens: u64,
        cache_write_tokens: u64,
        cost: f64,
    },
    /// Open the session picker.
    OpenSessionSelector,
    /// The session was given the name `name`.
    SessionNamed { name: String },
    /// Open the settings menu.
    OpenSettings,
    /// Open the model picker.
    OpenModelSelector,
    /// Choose which models take part in model cycling.
    ScopedModels,
    /// Export the session (HTML by default, or to `path` when given).
    ExportSession { path: Option<String> },
    /// Import a session from the JSONL file at `path` and resume it.
    ImportSession { path: String },
    /// Share the session as a secret GitHub gist.
    ShareSession,
    /// Copy the most recent agent message to the clipboard.
    CopyLastMessage,
    /// Display changelog entries.
    ShowChangelog,
    /// Fork a new session from an earlier user message.
    ForkSession { message_id: Option<String> },
    /// Duplicate the current session at its current position.
    CloneSession,
    /// Browse the session tree to switch branches.
    SessionTree,
    /// Persist the project trust decision.
    TrustDecision { decision: String },
    /// Set up authentication for a provider.
    Login {
        provider: Option<String>,
        api_key: Option<String>,
    },
    /// Remove a provider's authentication.
    Logout { provider: Option<String> },
    /// Compact the session context on demand, with optional extra text.
    CompactSession(Option<String>),
}
582
583/// A registered slash command.
584pub struct SlashCommand {
585    pub name: String,
586    pub description: String,
587    pub handler: Box<dyn CommandHandler>,
588}
589
/// Lightweight cancellation token shared between the agent loop and running tools.
///
/// The flag sits behind an `Arc`, so every clone observes the same cancellation state.
#[derive(Clone, Debug)]
pub struct Cancel {
    flag: Arc<AtomicBool>,
}
596
597impl Cancel {
598    pub fn new() -> Self {
599        Self {
600            flag: Arc::new(AtomicBool::new(false)),
601        }
602    }
603
604    /// Check whether cancellation has been requested.
605    pub fn is_cancelled(&self) -> bool {
606        self.flag.load(Ordering::Relaxed)
607    }
608
609    /// Request cancellation.
610    pub fn cancel(&self) {
611        self.flag.store(true, Ordering::Relaxed);
612    }
613
614    /// Check if cancelled, returning an error if so.
615    pub fn check(&self) -> anyhow::Result<()> {
616        if self.is_cancelled() {
617            Err(anyhow::anyhow!("Operation cancelled"))
618        } else {
619            Ok(())
620        }
621    }
622}
623
624impl Default for Cancel {
625    fn default() -> Self {
626        Self::new()
627    }
628}
629
630/// Context passed to ToolRenderer methods (matching pi's ToolRenderContext).
631/// Carries all metadata about the tool execution that renderers may need.
632#[derive(Debug, Clone)]
633pub struct ToolRenderContext {
634    pub expanded: bool,
635    pub args_complete: bool,
636    pub is_partial: bool,
637    pub is_error: bool,
638    /// Unique id for this tool execution (pi's toolCallId).
639    pub tool_call_id: String,
640    /// Whether the tool execution has started (pi's executionStarted).
641    pub execution_started: bool,
642    /// Working directory for path resolution.
643    pub cwd: String,
644    /// Duration in seconds (bash).
645    pub duration_secs: Option<f64>,
646    /// Exit code (bash).
647    pub exit_code: Option<i32>,
648    /// Whether execution was cancelled (bash).
649    pub cancelled: bool,
650    /// Whether output was truncated (bash/read).
651    pub was_truncated: bool,
652    /// Path to full output file (bash).
653    pub full_output_path: Option<String>,
654    /// File path for syntax highlighting (read).
655    pub file_path: Option<String>,
656    /// Keybinding hint for the expand action, e.g. "C-O".
657    pub expand_key: String,
658    /// Structured rendering details from the tool execution (pi-compatible).
659    /// Set by tool renderers for preview/actual diff data. Not sent to the LLM.
660    pub details: Option<serde_json::Value>,
661    /// Shared mutable state per tool execution (pi's context.state).
662    /// Initialized as an empty JSON object `{}`. Renderers can mutate it
663    /// across renderCall/renderResult invocations for the same tool call.
664    pub state: std::rc::Rc<std::cell::RefCell<serde_json::Value>>,
665    /// Callback for renderers to request re-render (e.g. after async preview computation).
666    /// Pi-compatible: `context.invalidate()` in renderCall/renderResult.
667    /// Cloned from the original at context construction time.
668    /// Uses a channel sender internally to bridge from async to UI thread.
669    pub invalidate: Option<tokio::sync::mpsc::UnboundedSender<()>>,
670}
671
672/// Tool-specific rendering interface (matching pi's renderCall/renderResult pattern).
673/// Each built-in tool implements this to provide its own visual representation.
674pub trait ToolRenderer: Send + Sync {
675    /// Render the tool call header/title.
676    /// Returns ANSI-styled lines for the call portion (inside the colored box shell).
677    fn render_call(
678        &self,
679        args: &serde_json::Value,
680        width: usize,
681        theme: &dyn Theme,
682        ctx: &ToolRenderContext,
683    ) -> Vec<String>;
684
685    /// Render the tool result body.
686    /// Returns lines to display as the result body, or empty vec for no result.
687    /// When empty, only the call portion is shown (e.g. write success).
688    fn render_result(
689        &self,
690        content: &str,
691        width: usize,
692        theme: &dyn Theme,
693        ctx: &ToolRenderContext,
694    ) -> Vec<String>;
695
696    /// Whether this tool uses `renderShell: "self"` (controls its own framing).
697    /// When true, ToolExecComponent does NOT wrap the tool in a colored background box.
698    fn render_self(&self) -> bool {
699        false
700    }
701
702    /// Optional hint for the background color key when `render_self()` is false.
703    /// Return a theme key name (e.g. "toolPendingBg", "toolSuccessBg", "toolErrorBg")
704    /// to override the default background selection. Return None to let the
705    /// ToolExecComponent decide based on is_complete/is_error state.
706    /// Used by edit tool to show success/error bg during preview.
707    fn render_bg_key(&self) -> Option<&'static str> {
708        None
709    }
710}
711
712#[async_trait::async_trait]
713impl yoagent::types::AgentTool for ToolDefinition {
714    fn name(&self) -> &str {
715        self.tool.name()
716    }
717
718    fn label(&self) -> &str {
719        self.tool.label()
720    }
721
722    fn description(&self) -> &str {
723        self.tool.description()
724    }
725
726    fn parameters_schema(&self) -> serde_json::Value {
727        self.tool.parameters_schema()
728    }
729
730    async fn execute(
731        &self,
732        params: serde_json::Value,
733        ctx: yoagent::types::ToolContext,
734    ) -> std::result::Result<yoagent::types::ToolResult, yoagent::types::ToolError> {
735        let mut params = match self.prepare_arguments {
736            Some(prepare) => prepare(params).map_err(yoagent::types::ToolError::InvalidArgs)?,
737            None => params,
738        };
739        // Step 1: type coercion (matching pi's Value.Convert + coerceWithJsonSchema)
740        let schema = self.tool.parameters_schema();
741        coerce_with_json_schema(&schema, &mut params);
742
743        // Step 2: validate against schema (matching pi's validateToolArguments)
744        let tool_name = self.tool.name();
745        validate_tool_arguments(tool_name, &schema, &params)
746            .map_err(yoagent::types::ToolError::InvalidArgs)?;
747
748        // Step 3: before_tool_call hook (matching pi's beforeToolCall)
749        if let Some(ref hook) = self.before_tool_call
750            && let Some(result) = hook(&params)
751            && result.block
752        {
753            let reason = if result.reason.is_empty() {
754                format!("Tool {} execution blocked", tool_name)
755            } else {
756                result.reason
757            };
758            return Err(yoagent::types::ToolError::Failed(reason));
759        }
760
761        // Step 4: execute the inner tool
762        let (mut tool_result, mut is_error) = match self.tool.execute(params, ctx).await {
763            Ok(r) => (r, false),
764            Err(e) => {
765                let err_text = e.to_string();
766                (
767                    yoagent::types::ToolResult {
768                        content: vec![yoagent::types::Content::Text { text: err_text }],
769                        details: serde_json::Value::Null,
770                    },
771                    true,
772                )
773            }
774        };
775
776        // Step 5: after_tool_call hook (matching pi's afterToolCall)
777        if let Some(ref hook) = self.after_tool_call
778            && let Some(override_result) = hook(&tool_result, is_error)
779        {
780            if let Some(content) = override_result.content {
781                tool_result.content = content;
782            }
783            if let Some(details) = override_result.details {
784                tool_result.details = details;
785            }
786            if let Some(err) = override_result.is_error {
787                is_error = err;
788            }
789        }
790
791        if is_error {
792            let error_text: String = tool_result
793                .content
794                .iter()
795                .filter_map(|c| {
796                    if let yoagent::types::Content::Text { text } = c {
797                        Some(text.as_str())
798                    } else {
799                        None
800                    }
801                })
802                .collect::<Vec<_>>()
803                .join("\n");
804            Err(yoagent::types::ToolError::Failed(error_text))
805        } else {
806            Ok(tool_result)
807        }
808    }
809}
810
811pub trait Extension: Send + Sync + std::any::Any {
812    fn name(&self) -> Cow<'static, str>;
813
814    /// Downcast to `&dyn Any` for downcasting to concrete types.
815    fn as_any(&self) -> &dyn std::any::Any;
816
817    /// Tools this extension provides (LLM-callable), each with its own prompt metadata.
818    fn tools(&self) -> Vec<ToolDefinition> {
819        vec![]
820    }
821
822    /// Slash commands this extension provides (e.g. `/quit`, `/model`).
823    /// Built-in commands and extension commands use the same interface.
824    fn commands(&self) -> Vec<SlashCommand> {
825        vec![]
826    }
827
828    /// Skills this extension provides (AgentSkills-compatible).
829    /// Merged into the session's skill set for /skill:name expansion and system prompt.
830    fn skills(&self) -> yoagent::skills::SkillSet {
831        yoagent::skills::SkillSet::empty()
832    }
833
834    /// Called before compaction runs (matching pi's `session_before_compact`).
835    /// Return `Some(BeforeCompactResult { cancel: true, .. })` to cancel compaction.
836    /// Return `Some(BeforeCompactResult { cancel: false, summary: Some(...), .. })`
837    /// to provide a custom summary instead of calling the provider.
838    /// Return `None` to let the default compaction proceed.
839    ///
840    /// `cancel` is a cancellation token — check `cancel.is_cancelled()` in
841    /// long-running hooks and return immediately if true (matching pi's
842    /// `AbortSignal` passed to `session_before_compact`).
843    fn before_compact(
844        &self,
845        _first_kept_entry_id: &str,
846        _tokens_before: u64,
847        _reason: &str,
848        _cancel: &Cancel,
849    ) -> Option<BeforeCompactResult> {
850        None
851    }
852
853    /// Called after compaction completes (matching pi's `session_compact`).
854    ///
855    /// `cancel` is a cancellation token — check `cancel.is_cancelled()` in
856    /// long-running hooks and return early if true (matching pi's
857    /// `AbortSignal` passed to `session_compact`).
858    #[allow(clippy::too_many_arguments)]
859    fn after_compact(
860        &self,
861        _summary: &str,
862        _first_kept_entry_id: &str,
863        _tokens_before: u64,
864        _estimated_tokens_after: u64,
865        _from_hook: bool,
866        _reason: &str,
867        _cancel: &Cancel,
868    ) {
869    }
870
871    /// Called when `/reload` is triggered (matching pi's `session_start` with reason "reload").
872    /// Extensions can refresh internal state, re-read configs, reconnect, etc.
873    fn on_reload(&self) {}
874
875    /// Called before the session is shut down or reloaded (matching pi's `session_shutdown`).
876    /// `reason` is "reload", "quit", "new", "resume", or "fork".
877    /// Extensions should save state, flush buffers, and prepare for teardown.
878    fn on_session_shutdown(&self, _reason: &str) {}
879
880    /// Called after the session starts or reloads (matching pi's `session_start`).
881    /// `reason` is "startup", "reload", "new", "resume", or "fork".
882    /// Extensions should restore state, re-register resources, reconnect.
883    fn on_session_start(&self, _reason: &str) {}
884}
885
886// ── Tests ──────────────────────────────────────────────────────────
887
#[cfg(test)]
mod tests {
    use super::*;

    /// Applies `coerce_primitive_by_type` to `input` and hands back the result,
    /// so each coercion case reads as a single `assert_eq!`.
    fn coerced(kind: &'static str, input: serde_json::Value) -> serde_json::Value {
        let mut value = input;
        coerce_primitive_by_type(kind, &mut value);
        value
    }

    // ── coerce_primitive_by_type ────────────────────────────────────

    #[test]
    fn test_coerce_string_from_number() {
        assert_eq!(coerced("string", serde_json::json!(42)), serde_json::json!("42"));
    }

    #[test]
    fn test_coerce_string_from_boolean() {
        assert_eq!(coerced("string", serde_json::json!(true)), serde_json::json!("true"));
    }

    #[test]
    fn test_coerce_string_from_null() {
        assert_eq!(coerced("string", serde_json::json!(null)), serde_json::json!(""));
    }

    #[test]
    fn test_coerce_string_unchanged() {
        assert_eq!(coerced("string", serde_json::json!("hello")), serde_json::json!("hello"));
    }

    #[test]
    fn test_coerce_number_from_string() {
        assert_eq!(coerced("number", serde_json::json!("42.5")), serde_json::json!(42.5));
    }

    #[test]
    fn test_coerce_number_from_boolean() {
        assert_eq!(coerced("number", serde_json::json!(true)), serde_json::json!(1.0));
    }

    #[test]
    fn test_coerce_number_from_null() {
        assert_eq!(coerced("number", serde_json::json!(null)), serde_json::json!(0.0));
    }

    #[test]
    fn test_coerce_integer_from_string() {
        assert_eq!(coerced("integer", serde_json::json!("7")), serde_json::json!(7i64));
    }

    #[test]
    fn test_coerce_integer_from_float() {
        assert_eq!(coerced("integer", serde_json::json!(3.9)), serde_json::json!(3i64));
    }

    #[test]
    fn test_coerce_integer_from_boolean() {
        assert_eq!(coerced("integer", serde_json::json!(false)), serde_json::json!(0i64));
    }

    #[test]
    fn test_coerce_boolean_from_string_true() {
        assert_eq!(coerced("boolean", serde_json::json!("true")), serde_json::json!(true));
    }

    #[test]
    fn test_coerce_boolean_from_string_yes() {
        assert_eq!(coerced("boolean", serde_json::json!("yes")), serde_json::json!(true));
    }

    #[test]
    fn test_coerce_boolean_from_number() {
        assert_eq!(coerced("boolean", serde_json::json!(1)), serde_json::json!(true));
    }

    #[test]
    fn test_coerce_boolean_from_null() {
        assert_eq!(coerced("boolean", serde_json::json!(null)), serde_json::json!(false));
    }

    #[test]
    fn test_coerce_array_from_scalar() {
        assert_eq!(coerced("array", serde_json::json!("single")), serde_json::json!(["single"]));
    }

    #[test]
    fn test_coerce_array_from_null() {
        assert_eq!(coerced("array", serde_json::json!(null)), serde_json::json!([]));
    }

    #[test]
    fn test_coerce_array_unchanged() {
        assert_eq!(coerced("array", serde_json::json!([1, 2, 3])), serde_json::json!([1, 2, 3]));
    }

    #[test]
    fn test_coerce_unknown_type_does_nothing() {
        assert_eq!(coerced("widget", serde_json::json!(42)), serde_json::json!(42));
    }

    // ── coerce_with_json_schema ─────────────────────────────────────

    #[test]
    fn test_coerce_schema_string_from_number() {
        let mut payload = serde_json::json!({"name": 42});
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            }
        });
        coerce_with_json_schema(&schema, &mut payload);
        assert_eq!(payload, serde_json::json!({"name": "42"}));
    }

    #[test]
    fn test_coerce_schema_nested_object() {
        let mut payload = serde_json::json!({"metadata": {"count": "5"}});
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "metadata": {
                    "type": "object",
                    "properties": {
                        "count": {"type": "integer"}
                    }
                }
            }
        });
        coerce_with_json_schema(&schema, &mut payload);
        let expected = serde_json::json!({"metadata": {"count": 5i64}});
        assert_eq!(payload, expected);
    }

    #[test]
    fn test_coerce_schema_array_items() {
        let mut payload = serde_json::json!({"items": [{"id": "3"}, {"id": "7"}]});
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "items": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "id": {"type": "integer"}
                        }
                    }
                }
            }
        });
        coerce_with_json_schema(&schema, &mut payload);
        let expected = serde_json::json!({"items": [{"id": 3i64}, {"id": 7i64}]});
        assert_eq!(payload, expected);
    }

    #[test]
    fn test_coerce_schema_non_object_skipped() {
        let mut payload = serde_json::json!("hello");
        let schema = serde_json::json!({"type": "string"});
        coerce_with_json_schema(&schema, &mut payload);
        assert_eq!(payload, serde_json::json!("hello"));
    }

    // ── validate_tool_arguments ─────────────────────────────────────

    #[test]
    fn test_validate_valid_args() {
        let arguments = serde_json::json!({"path": "/tmp/foo.txt"});
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "path": {"type": "string"}
            },
            "required": ["path"]
        });
        let verdict = validate_tool_arguments("test", &schema, &arguments);
        assert!(verdict.is_ok());
    }

    #[test]
    fn test_validate_missing_required() {
        let arguments = serde_json::json!({});
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "path": {"type": "string"}
            },
            "required": ["path"]
        });
        let message = validate_tool_arguments("test", &schema, &arguments).unwrap_err();
        assert!(message.contains("Required"));
        assert!(message.contains("test"));
    }

    #[test]
    fn test_validate_wrong_type() {
        let arguments = serde_json::json!({"count": "not-a-number"});
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "count": {"type": "integer"}
            }
        });
        let message = validate_tool_arguments("test", &schema, &arguments).unwrap_err();
        assert!(message.contains("Expected integer"));
    }

    #[test]
    fn test_validate_additional_properties() {
        let arguments = serde_json::json!({"name": "alice", "extra": "bad"});
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            },
            "additionalProperties": false
        });
        let message = validate_tool_arguments("test", &schema, &arguments).unwrap_err();
        assert!(message.contains("must NOT have additional properties"));
    }

    #[test]
    fn test_validate_not_an_object() {
        let arguments = serde_json::json!("a string, not an object");
        let schema = serde_json::json!({
            "type": "object",
            "properties": {}
        });
        let message = validate_tool_arguments("test", &schema, &arguments).unwrap_err();
        assert!(message.contains("Expected object"));
    }

    #[test]
    fn test_validate_array_item_types() {
        let arguments = serde_json::json!({"tags": [1, 2, 3]});
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "tags": {
                    "type": "array",
                    "items": {"type": "string"}
                }
            }
        });
        let message = validate_tool_arguments("test", &schema, &arguments).unwrap_err();
        assert!(message.contains("Expected string"));
    }

    // ── Cancel ──────────────────────────────────────────────────────

    #[test]
    fn test_cancel_new_not_cancelled() {
        let token = Cancel::new();
        assert!(!token.is_cancelled());
        token.check().unwrap();
    }

    #[test]
    fn test_cancel_after_cancel() {
        let token = Cancel::new();
        token.cancel();
        assert!(token.is_cancelled());
        assert!(token.check().is_err());
    }

    #[test]
    fn test_cancel_default_not_cancelled() {
        let token: Cancel = Default::default();
        assert!(!token.is_cancelled());
    }

    #[test]
    fn test_cancel_is_send_sync() {
        // Compile-time check: this only builds if `Cancel` is both Send and Sync.
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<Cancel>();
    }
}