Skip to main content

rab/agent/
extension.rs

//! Extension trait - all capability (built-in or user-provided) comes through this.
2use crate::tui::Theme;
3use std::borrow::Cow;
4use std::sync::{
5    Arc,
6    atomic::{AtomicBool, Ordering},
7};
8
9// ── Tool call hooks (matching pi's beforeToolCall / afterToolCall) ──
10
/// Result returned from `before_tool_call` (matching pi's `BeforeToolCallResult`).
///
/// Returning `block: true` prevents execution; the call then fails with `reason` as its
/// error text. `Default` is the non-blocking result (`block: false`, empty `reason`).
#[derive(Debug, Clone, Default)]
pub struct BeforeToolCallResult {
    /// If true, the tool execution is blocked.
    pub block: bool,
    /// Error message shown when `block` is true. If empty, a default message is used.
    pub reason: String,
}
19
20/// Partial override returned from `after_tool_call` (matching pi's `AfterToolCallResult`).
21/// Merge semantics are field-by-field: provided fields replace the original; omitted fields keep their values.
22pub struct AfterToolCallResult {
23    /// If provided, replaces the tool result content array in full.
24    pub content: Option<Vec<yoagent::types::Content>>,
25    /// If provided, replaces the tool result details value in full.
26    pub details: Option<serde_json::Value>,
27    /// If provided, replaces the tool result error flag.
28    pub is_error: Option<bool>,
29}
30
31/// Result returned from `before_compact` (matching pi's `SessionBeforeCompactResult`).
32/// Returning `{ cancel: true }` prevents compaction.
33pub struct BeforeCompactResult {
34    /// If true, compaction is cancelled entirely.
35    pub cancel: bool,
36    /// If provided, uses this summary instead of calling the provider.
37    pub summary: Option<String>,
38    /// Optional details stored with the compaction entry.
39    pub details: Option<serde_json::Value>,
40}
41
/// A tool bundled with its prompt metadata.
///
/// Mirrors pi's `ToolDefinition` which carries `promptSnippet`,
/// `promptGuidelines` and `prepareArguments` directly on the tool definition.
///
/// A `ToolDefinition` is itself a `yoagent::types::AgentTool`. Its `execute` runs, in
/// order: `prepare_arguments`, schema coercion and validation, `before_tool_call`, the
/// wrapped `tool`, and finally `after_tool_call`.
pub struct ToolDefinition {
    /// The wrapped tool. Its name, label, description and parameter schema are what this
    /// definition exposes, and it performs the actual execution.
    pub tool: Box<dyn yoagent::types::AgentTool>,
    /// One-line snippet for the "Available tools" section of the system prompt.
    pub snippet: &'static str,
    /// Guideline bullets for the "Guidelines" section of the system prompt.
    pub guidelines: &'static [&'static str],
    /// Optional pre-processing of raw LLM arguments before execute().
    /// Receives raw arguments, returns normalized arguments or an error message.
    /// Runs before schema coercion/validation; an `Err` is reported to the caller as
    /// `ToolError::InvalidArgs`.
    pub prepare_arguments: Option<fn(serde_json::Value) -> Result<serde_json::Value, String>>,
    /// Called before tool execution, after argument validation (matching pi's `beforeToolCall`).
    /// Receives the coerced, validated arguments.
    /// Return `Some(BeforeToolCallResult { block: true, reason: "..." })` to block execution.
    pub before_tool_call: Option<fn(&serde_json::Value) -> Option<BeforeToolCallResult>>,
    /// Called after tool execution, before the result is returned (matching pi's `afterToolCall`).
    /// Receives the result and whether it is currently an error. An error from the inner
    /// tool arrives as a result whose only content is that error's text. Each field set in
    /// the returned `AfterToolCallResult` overrides the corresponding value.
    pub after_tool_call:
        Option<fn(&yoagent::types::ToolResult, bool) -> Option<AfterToolCallResult>>,
    /// Tool-specific renderer for the TUI, bundled with the tool definition
    /// (pi's renderCall/renderResult live on ToolDefinition).
    pub renderer: Option<Arc<dyn ToolRenderer>>,
}
65
66// ── Generic argument type coercion & validation ─────────────────
67
68/// Coerce a single JSON value to match a JSON Schema type (modifies in place).
69/// This handles common LLM mistakes: sending numbers as strings, booleans as strings, etc.
70pub fn coerce_primitive_by_type(schema_type: &str, value: &mut serde_json::Value) {
71    match schema_type {
72        "string" => {
73            if value.is_number() || value.is_boolean() {
74                *value = serde_json::Value::String(match value {
75                    serde_json::Value::Number(n) => n.to_string(),
76                    serde_json::Value::Bool(b) => b.to_string(),
77                    _ => unreachable!(),
78                });
79            } else if value.is_null() {
80                *value = serde_json::Value::String(String::new());
81            } else if value.is_array() || value.is_object() {
82                // TypeBox's Value.Convert stringifies arrays/objects when schema expects string
83                *value =
84                    serde_json::Value::String(serde_json::to_string(value).unwrap_or_default());
85            }
86        }
87        "number" => {
88            if let Some(s) = value.as_str() {
89                if let Ok(n) = s.parse::<f64>() {
90                    *value = serde_json::json!(n);
91                }
92            } else if value.is_boolean() {
93                *value = serde_json::json!(if value.as_bool().unwrap() { 1.0 } else { 0.0 });
94            } else if value.is_null() {
95                *value = serde_json::json!(0.0);
96            }
97        }
98        "integer" => {
99            if let Some(s) = value.as_str() {
100                if let Ok(n) = s.parse::<f64>() {
101                    *value = serde_json::json!(n as i64);
102                }
103            } else if value.is_boolean() {
104                *value = serde_json::json!(if value.as_bool().unwrap() { 1i64 } else { 0i64 });
105            } else if value.is_null() {
106                *value = serde_json::json!(0i64);
107            } else if let Some(n) = value.as_f64() {
108                *value = serde_json::json!(n as i64);
109            }
110        }
111        "boolean" => {
112            if let Some(s) = value.as_str() {
113                match s.trim().to_lowercase().as_str() {
114                    "true" | "1" | "yes" | "on" => *value = serde_json::Value::Bool(true),
115                    "false" | "0" | "no" | "off" => *value = serde_json::Value::Bool(false),
116                    _ => {} // Leave as-is if unrecognized
117                }
118            } else if value.is_number() {
119                *value = serde_json::Value::Bool(value.as_f64().unwrap_or(0.0) != 0.0);
120            } else if value.is_null() {
121                *value = serde_json::Value::Bool(false);
122            }
123        }
124        "null" => {
125            // Pi-compatible: treat empty string, 0, or false as null
126            if value.as_str().is_some_and(|s| s.is_empty())
127                || value.as_f64() == Some(0.0)
128                || value.as_bool() == Some(false)
129            {
130                *value = serde_json::Value::Null;
131            }
132        }
133        "array" => {
134            if !value.is_array() && !value.is_null() {
135                let v = std::mem::take(value);
136                *value = serde_json::Value::Array(vec![v]);
137            } else if value.is_null() {
138                *value = serde_json::Value::Array(vec![]);
139            }
140        }
141        _ => {}
142    }
143}
144
145/// Recursively coerce tool arguments to match a JSON Schema (modifies in place).
146pub fn coerce_with_json_schema(schema: &serde_json::Value, args: &mut serde_json::Value) {
147    // Handle composed schemas (matching pi's coerceWithJsonSchema order)
148    if let Some(all_of) = schema.get("allOf").and_then(|v| v.as_array()) {
149        for sub in all_of {
150            coerce_with_json_schema(sub, args);
151        }
152    }
153
154    if let Some(any_of) = schema.get("anyOf").and_then(|v| v.as_array()) {
155        // Try each anyOf alternative; keep the first that changes the value
156        if !any_of.is_empty() {
157            let original = args.clone();
158            for sub in any_of {
159                let mut candidate = original.clone();
160                coerce_with_json_schema(sub, &mut candidate);
161                if candidate != original {
162                    *args = candidate;
163                    break;
164                }
165            }
166        }
167    }
168
169    if let Some(one_of) = schema.get("oneOf").and_then(|v| v.as_array()) {
170        // Same strategy for oneOf
171        if !one_of.is_empty() {
172            let original = args.clone();
173            for sub in one_of {
174                let mut candidate = original.clone();
175                coerce_with_json_schema(sub, &mut candidate);
176                if candidate != original {
177                    *args = candidate;
178                    break;
179                }
180            }
181        }
182    }
183
184    if !args.is_object() {
185        return;
186    }
187    let Some(properties) = schema.get("properties").and_then(|p| p.as_object()) else {
188        return;
189    };
190    for (key, prop_schema) in properties {
191        if args.get(key).is_none() {
192            continue;
193        }
194        let arg_value = args.get_mut(key).unwrap();
195
196        // Try each schema type in order (matching pi's approach of iterating types)
197        let schema_types = collect_schema_types(prop_schema);
198        if !schema_types.is_empty() {
199            // Check if value already matches one of the types
200            let already_matches = schema_types.iter().any(|t| matches_json_type(arg_value, t));
201            if !already_matches {
202                for st in &schema_types {
203                    let before = arg_value.clone();
204                    coerce_primitive_by_type(st, arg_value);
205                    if *arg_value != before {
206                        break;
207                    }
208                }
209            }
210
211            // Recurse into objects and arrays
212            if schema_types.iter().any(|t| t == "object") && arg_value.is_object() {
213                coerce_with_json_schema(prop_schema, arg_value);
214            }
215            if schema_types.iter().any(|t| t == "array")
216                && let Some(items_schema) = prop_schema.get("items")
217                && let Some(arr) = arg_value.as_array_mut()
218            {
219                for item in arr.iter_mut() {
220                    coerce_with_json_schema(items_schema, item);
221                }
222            }
223        }
224    }
225}
226
227/// Collect all type names from a schema property, handling both plain strings and arrays.
228fn collect_schema_types(schema: &serde_json::Value) -> Vec<String> {
229    let type_val = match schema.get("type") {
230        Some(t) => t,
231        None => return vec![],
232    };
233    if let Some(s) = type_val.as_str() {
234        return vec![s.to_string()];
235    }
236    if let Some(arr) = type_val.as_array() {
237        return arr
238            .iter()
239            .filter_map(|t| t.as_str().map(|s| s.to_string()))
240            .collect();
241    }
242    vec![]
243}
244
245// ── Schema validation (matching pi's validateToolArguments) ──────
246
247/// Extracts the effective JSON Schema type from a property schema.
248/// Returns `None` when the schema has no recognizable type.
249fn resolve_schema_type(schema: &serde_json::Value) -> Option<&str> {
250    let type_val = schema.get("type")?;
251    if type_val.is_string() {
252        return type_val.as_str();
253    }
254    if type_val.is_array() {
255        // Use the first non-null type (handles ["string", "null"])
256        // This is still used by validate_tool_arguments for single-type checks
257        return type_val
258            .as_array()
259            .and_then(|arr| arr.iter().find_map(|t| t.as_str().filter(|&s| s != "null")));
260    }
261    None
262}
263fn matches_json_type(value: &serde_json::Value, schema_type: &str) -> bool {
264    match schema_type {
265        "string" => value.is_string(),
266        "number" => value.is_number(),
267        "integer" => value.is_i64() || value.is_u64(),
268        "boolean" => value.is_boolean(),
269        "null" => value.is_null(),
270        "array" => value.is_array(),
271        "object" => value.is_object(),
272        _ => true, // unknown type — don't reject
273    }
274}
275
276/// Check whether a value matches at least one of the schema's types (handles ["string", "null"]).
277fn value_matches_schema_types(schema: &serde_json::Value, value: &serde_json::Value) -> bool {
278    let type_val = match schema.get("type") {
279        Some(t) => t,
280        None => return true,
281    };
282    if type_val.is_string() {
283        return matches_json_type(value, type_val.as_str().unwrap());
284    }
285    if let Some(types) = type_val.as_array() {
286        return types
287            .iter()
288            .filter_map(|t| t.as_str())
289            .any(|t| matches_json_type(value, t));
290    }
291    true
292}
293
294/// Recursively collect validation errors for a value against a JSON Schema.
295/// Path format matches pi's formatValidationPath: "root", "edits", "edits.0.oldText".
296fn collect_validation_errors(
297    schema: &serde_json::Value,
298    value: &serde_json::Value,
299    path: &str,
300    errors: &mut Vec<ValidationError>,
301) {
302    // Root must be an object — every tool schema is "type": "object"
303    if (path.is_empty() || path == "root")
304        && let Some(schema_type) = resolve_schema_type(schema)
305        && schema_type == "object"
306        && !value.is_object()
307    {
308        errors.push(ValidationError {
309            path: path.to_string(),
310            message: "Expected object".to_string(),
311        });
312        return;
313    }
314
315    // Not an object — only check type (won't recurse)
316    if !value.is_object()
317        && let Some(schema_type) = resolve_schema_type(schema)
318        && !matches_json_type(value, schema_type)
319    {
320        let expected = if schema_type == "integer" {
321            "integer"
322        } else {
323            schema_type
324        };
325        errors.push(ValidationError {
326            path: path.to_string(),
327            message: format!("Expected {}", expected),
328        });
329        return;
330    }
331
332    if !value.is_object() {
333        return;
334    }
335
336    let obj = value.as_object().unwrap();
337    let properties = schema.get("properties").and_then(|p| p.as_object());
338    let known_keys: std::collections::HashSet<&str> = properties
339        .map(|p| p.keys().map(|k| k.as_str()).collect())
340        .unwrap_or_default();
341
342    // Check required properties
343    if let Some(required) = schema.get("required").and_then(|r| r.as_array()) {
344        for required_val in required {
345            if let Some(required_key) = required_val.as_str()
346                && !obj.contains_key(required_key)
347            {
348                let err_path = if path.is_empty() || path == "root" {
349                    required_key.to_string()
350                } else {
351                    format!("{}.{}", path, required_key)
352                };
353                errors.push(ValidationError {
354                    path: err_path,
355                    message: "Required".to_string(),
356                });
357            }
358        }
359    }
360
361    // Check additionalProperties
362    if schema.get("additionalProperties") == Some(&serde_json::Value::Bool(false)) {
363        for key in obj.keys() {
364            if !known_keys.contains(key.as_str()) {
365                let err_path = if path.is_empty() || path == "root" {
366                    key.clone()
367                } else {
368                    format!("{}.{}", path, key)
369                };
370                errors.push(ValidationError {
371                    path: err_path,
372                    message: "must NOT have additional properties".to_string(),
373                });
374            }
375        }
376    }
377
378    // Validate each property
379    if let Some(props) = properties {
380        for (key, prop_schema) in props {
381            if let Some(val) = value.get(key) {
382                let child_path = if path.is_empty() || path == "root" {
383                    key.clone()
384                } else {
385                    format!("{}.{}", path, key)
386                };
387                validate_property(prop_schema, val, &child_path, errors);
388            }
389        }
390    }
391}
392
393/// Validate a single property value against its schema, recursing into objects/arrays.
394fn validate_property(
395    schema: &serde_json::Value,
396    value: &serde_json::Value,
397    path: &str,
398    errors: &mut Vec<ValidationError>,
399) {
400    // Check type match
401    if !value_matches_schema_types(schema, value) {
402        let schema_type = resolve_schema_type(schema).unwrap_or("unknown");
403        let expected = if schema_type == "integer" {
404            "integer"
405        } else {
406            schema_type
407        };
408        errors.push(ValidationError {
409            path: path.to_string(),
410            message: format!("Expected {}", expected),
411        });
412        return; // Don't recurse into wrong-typed values
413    }
414
415    // Recurse into objects
416    if value.is_object() {
417        // Only recurse if the schema also describes an object
418        let schema_type = resolve_schema_type(schema);
419        if schema_type == Some("object") {
420            collect_validation_errors(schema, value, path, errors);
421        }
422        return;
423    }
424
425    // Recurse into array items
426    if let Some(arr) = value.as_array()
427        && resolve_schema_type(schema) == Some("array")
428        && let Some(items_schema) = schema.get("items")
429    {
430        for (i, item) in arr.iter().enumerate() {
431            let item_path = format!("{}.{}", path, i);
432            validate_property(items_schema, item, &item_path, errors);
433        }
434    }
435}
436
/// A single validation error, matching pi's TypeBox error structure.
///
/// Produced while validating tool arguments against their JSON Schema. Each error is
/// rendered as one `  - <path>: <message>` line of the validation failure message.
#[derive(Debug, Clone)]
pub struct ValidationError {
    /// Path to the field, e.g. "edits.0.oldText" or "root".
    /// Segments are dot-separated, and array indices are plain numeric segments.
    pub path: String,
    /// Error message, e.g. "Required" or "must NOT have additional properties".
    /// Type mismatches read "Expected <type>".
    pub message: String,
}
445
446/// Validate tool arguments against its JSON Schema (matching pi's validateToolArguments).
447///
448/// Returns `Ok(())` on success, or `Err` with pi-compatible format:
449/// ```text
450/// Validation failed for tool "edit":
451///   - path: Required
452///   - edits[0].oldText: Required
453///
454/// Received arguments:
455/// {
456///   "path": "/foo.txt"
457/// }
458/// ```
459pub fn validate_tool_arguments(
460    tool_name: &str,
461    schema: &serde_json::Value,
462    args: &serde_json::Value,
463) -> Result<(), String> {
464    let mut errors: Vec<ValidationError> = Vec::new();
465    collect_validation_errors(schema, args, "root", &mut errors);
466
467    if errors.is_empty() {
468        return Ok(());
469    }
470
471    let error_lines: Vec<String> = errors
472        .iter()
473        .map(|e| format!("  - {}: {}", e.path, e.message))
474        .collect();
475
476    let pretty_args =
477        serde_json::to_string_pretty(args).unwrap_or_else(|_| "<unprintable>".to_string());
478
479    Err(format!(
480        "Validation failed for tool \"{tool_name}\":\n{}\n\nReceived arguments:\n{pretty_args}",
481        error_lines.join("\n"),
482    ))
483}
484
/// An autocomplete item for slash command arguments.
///
/// Returned by `CommandHandler::argument_completions`.
#[derive(Debug, Clone)]
pub struct AutocompleteItem {
    /// The value to insert when selected.
    pub value: String,
    /// Display label (kept separately from `value`, so it may differ from the inserted text).
    pub label: String,
    /// Optional description; `None` when there is nothing extra to show.
    pub description: Option<String>,
}
495
496/// A slash command handler (built-in or extension-provided).
497/// Commands use the same Extension trait as tools - built-ins and
498/// user extensions register commands through a uniform interface.
499pub trait CommandHandler: Send + Sync {
500    /// Execute the command with the given arguments string.
501    fn execute(&self, args: &str) -> anyhow::Result<CommandResult>;
502
503    /// Get argument completions for autocomplete.
504    /// Called when user types `/cmd ` - returns matching autocomplete items.
505    fn argument_completions(&self, _prefix: &str) -> Vec<AutocompleteItem> {
506        vec![]
507    }
508}
509
/// Result of executing a slash command.
///
/// Returned by `CommandHandler::execute`. Each variant names the outcome the caller is
/// expected to act on: show a message, quit, switch sessions, open a UI, and so on.
#[derive(Debug, Clone)]
pub enum CommandResult {
    /// Command handled, show this info message.
    Info(String),
    /// Command caused a quit request.
    Quit,
    /// Command switched the model (new model name).
    ModelChanged(String),
    /// Show keyboard shortcuts help overlay.
    ShowHelp,
    /// Reload settings, extensions, keybindings, themes from disk.
    Reloaded,
    /// Start a new session (clear conversation).
    NewSession,
    /// Switch to a different session file.
    SessionSwitched { path: std::path::PathBuf },
    /// Show session info (ID, file, messages, tokens, cost).
    SessionInfo {
        session_id: String,
        /// Session file, if any.
        file_path: Option<std::path::PathBuf>,
        /// Session name, if any.
        name: Option<String>,
        // NOTE(review): the counters below are presumably tallies over the session's
        // message history — confirm with the code that builds this variant.
        message_count: usize,
        user_messages: usize,
        assistant_messages: usize,
        tool_calls: usize,
        tool_results: usize,
        total_tokens: u64,
        input_tokens: u64,
        output_tokens: u64,
        cache_read_tokens: u64,
        cache_write_tokens: u64,
        /// NOTE(review): currency/unit is not visible here (presumably USD) — confirm.
        cost: f64,
    },
    /// Open session selector UI.
    OpenSessionSelector,
    /// Name was set for the session.
    SessionNamed { name: String },
    /// Open settings menu.
    OpenSettings,
    /// Enable/disable models for cycling.
    ScopedModels,
    /// Export session (HTML default, or specify path).
    ExportSession { path: Option<String> },
    /// Import and resume a session from a JSONL file.
    ImportSession { path: String },
    /// Share session as a secret GitHub gist.
    ShareSession,
    /// Copy last agent message to clipboard.
    CopyLastMessage,
    /// Show changelog entries.
    ShowChangelog,
    /// Create a new fork from a previous user message.
    /// `None` presumably means "let the user pick" — confirm with the handler.
    ForkSession { message_id: Option<String> },
    /// Duplicate the current session at the current position.
    CloneSession,
    /// Navigate session tree (switch branches).
    SessionTree,
    /// Save project trust decision.
    TrustDecision { decision: String },
    /// Configure provider authentication.
    Login { provider: Option<String> },
    /// Remove provider authentication.
    Logout { provider: Option<String> },
    /// Manually compact the session context.
    /// NOTE(review): the `Option<String>` presumably carries optional custom instructions
    /// taken from the command arguments — confirm with the compact command handler.
    CompactSession(Option<String>),
}
577
/// A registered slash command.
///
/// Extensions return these from `Extension::commands`; `handler` does the actual work.
pub struct SlashCommand {
    /// Command name.
    /// NOTE(review): whether this includes the leading `/` is not visible here — confirm
    /// against the command dispatcher.
    pub name: String,
    /// Human-readable description of the command.
    pub description: String,
    /// Executes the command and supplies argument completions.
    pub handler: Box<dyn CommandHandler>,
}
584
585/// Simple cancellation token for tool execution.
586/// Shared between the agent loop and tool execution to signal cancellation.
587#[derive(Debug, Clone)]
588pub struct Cancel {
589    flag: Arc<AtomicBool>,
590}
591
592impl Cancel {
593    pub fn new() -> Self {
594        Self {
595            flag: Arc::new(AtomicBool::new(false)),
596        }
597    }
598
599    /// Check whether cancellation has been requested.
600    pub fn is_cancelled(&self) -> bool {
601        self.flag.load(Ordering::Relaxed)
602    }
603
604    /// Request cancellation.
605    pub fn cancel(&self) {
606        self.flag.store(true, Ordering::Relaxed);
607    }
608
609    /// Check if cancelled, returning an error if so.
610    pub fn check(&self) -> anyhow::Result<()> {
611        if self.is_cancelled() {
612            Err(anyhow::anyhow!("Operation cancelled"))
613        } else {
614            Ok(())
615        }
616    }
617}
618
619impl Default for Cancel {
620    fn default() -> Self {
621        Self::new()
622    }
623}
624
/// Context passed to ToolRenderer methods (matching pi's ToolRenderContext).
/// Carries all metadata about the tool execution that renderers may need.
///
/// `state` is an `Rc<RefCell<..>>`, so this type is neither `Send` nor `Sync`, and
/// clones of a context share the same `state` cell.
#[derive(Debug, Clone)]
pub struct ToolRenderContext {
    /// NOTE(review): presumably whether the expanded view is shown (toggled via
    /// `expand_key`) — confirm with the component that sets it.
    pub expanded: bool,
    /// NOTE(review): presumably whether the tool-call arguments have been fully
    /// received — confirm with the producer.
    pub args_complete: bool,
    /// NOTE(review): presumably `true` while the result is still partial/streaming —
    /// confirm with the producer.
    pub is_partial: bool,
    /// NOTE(review): presumably whether the tool result is an error — confirm.
    pub is_error: bool,
    /// Unique id for this tool execution (pi's toolCallId).
    pub tool_call_id: String,
    /// Whether the tool execution has started (pi's executionStarted).
    pub execution_started: bool,
    /// Working directory for path resolution.
    pub cwd: String,
    /// Duration in seconds (bash).
    pub duration_secs: Option<f64>,
    /// Exit code (bash).
    pub exit_code: Option<i32>,
    /// Whether execution was cancelled (bash).
    pub cancelled: bool,
    /// Whether output was truncated (bash/read).
    pub was_truncated: bool,
    /// Path to full output file (bash).
    pub full_output_path: Option<String>,
    /// File path for syntax highlighting (read).
    pub file_path: Option<String>,
    /// Keybinding hint for the expand action, e.g. "C-O".
    pub expand_key: String,
    /// Structured rendering details from the tool execution (pi-compatible).
    /// Set by tool renderers for preview/actual diff data. Not sent to the LLM.
    pub details: Option<serde_json::Value>,
    /// Shared mutable state per tool execution (pi's context.state).
    /// Initialized as an empty JSON object `{}`. Renderers can mutate it
    /// across renderCall/renderResult invocations for the same tool call.
    pub state: std::rc::Rc<std::cell::RefCell<serde_json::Value>>,
    /// Callback for renderers to request re-render (e.g. after async preview computation).
    /// Pi-compatible: `context.invalidate()` in renderCall/renderResult.
    /// Cloned from the original at context construction time.
    /// Uses a channel sender internally to bridge from async to UI thread.
    pub invalidate: Option<tokio::sync::mpsc::UnboundedSender<()>>,
}
666
667/// Tool-specific rendering interface (matching pi's renderCall/renderResult pattern).
668/// Each built-in tool implements this to provide its own visual representation.
669pub trait ToolRenderer: Send + Sync {
670    /// Render the tool call header/title.
671    /// Returns ANSI-styled lines for the call portion (inside the colored box shell).
672    fn render_call(
673        &self,
674        args: &serde_json::Value,
675        width: usize,
676        theme: &dyn Theme,
677        ctx: &ToolRenderContext,
678    ) -> Vec<String>;
679
680    /// Render the tool result body.
681    /// Returns lines to display as the result body, or empty vec for no result.
682    /// When empty, only the call portion is shown (e.g. write success).
683    fn render_result(
684        &self,
685        content: &str,
686        width: usize,
687        theme: &dyn Theme,
688        ctx: &ToolRenderContext,
689    ) -> Vec<String>;
690
691    /// Whether this tool uses `renderShell: "self"` (controls its own framing).
692    /// When true, ToolExecComponent does NOT wrap the tool in a colored background box.
693    fn render_self(&self) -> bool {
694        false
695    }
696
697    /// Optional hint for the background color key when `render_self()` is false.
698    /// Return a theme key name (e.g. "toolPendingBg", "toolSuccessBg", "toolErrorBg")
699    /// to override the default background selection. Return None to let the
700    /// ToolExecComponent decide based on is_complete/is_error state.
701    /// Used by edit tool to show success/error bg during preview.
702    fn render_bg_key(&self) -> Option<&'static str> {
703        None
704    }
705}
706
707#[async_trait::async_trait]
708impl yoagent::types::AgentTool for ToolDefinition {
709    fn name(&self) -> &str {
710        self.tool.name()
711    }
712
713    fn label(&self) -> &str {
714        self.tool.label()
715    }
716
717    fn description(&self) -> &str {
718        self.tool.description()
719    }
720
721    fn parameters_schema(&self) -> serde_json::Value {
722        self.tool.parameters_schema()
723    }
724
725    async fn execute(
726        &self,
727        params: serde_json::Value,
728        ctx: yoagent::types::ToolContext,
729    ) -> std::result::Result<yoagent::types::ToolResult, yoagent::types::ToolError> {
730        let mut params = match self.prepare_arguments {
731            Some(prepare) => prepare(params).map_err(yoagent::types::ToolError::InvalidArgs)?,
732            None => params,
733        };
734        // Step 1: type coercion (matching pi's Value.Convert + coerceWithJsonSchema)
735        let schema = self.tool.parameters_schema();
736        coerce_with_json_schema(&schema, &mut params);
737
738        // Step 2: validate against schema (matching pi's validateToolArguments)
739        let tool_name = self.tool.name();
740        validate_tool_arguments(tool_name, &schema, &params)
741            .map_err(yoagent::types::ToolError::InvalidArgs)?;
742
743        // Step 3: before_tool_call hook (matching pi's beforeToolCall)
744        if let Some(ref hook) = self.before_tool_call
745            && let Some(result) = hook(&params)
746            && result.block
747        {
748            let reason = if result.reason.is_empty() {
749                format!("Tool {} execution blocked", tool_name)
750            } else {
751                result.reason
752            };
753            return Err(yoagent::types::ToolError::Failed(reason));
754        }
755
756        // Step 4: execute the inner tool
757        let (mut tool_result, mut is_error) = match self.tool.execute(params, ctx).await {
758            Ok(r) => (r, false),
759            Err(e) => {
760                let err_text = e.to_string();
761                (
762                    yoagent::types::ToolResult {
763                        content: vec![yoagent::types::Content::Text { text: err_text }],
764                        details: serde_json::Value::Null,
765                    },
766                    true,
767                )
768            }
769        };
770
771        // Step 5: after_tool_call hook (matching pi's afterToolCall)
772        if let Some(ref hook) = self.after_tool_call
773            && let Some(override_result) = hook(&tool_result, is_error)
774        {
775            if let Some(content) = override_result.content {
776                tool_result.content = content;
777            }
778            if let Some(details) = override_result.details {
779                tool_result.details = details;
780            }
781            if let Some(err) = override_result.is_error {
782                is_error = err;
783            }
784        }
785
786        if is_error {
787            let error_text: String = tool_result
788                .content
789                .iter()
790                .filter_map(|c| {
791                    if let yoagent::types::Content::Text { text } = c {
792                        Some(text.as_str())
793                    } else {
794                        None
795                    }
796                })
797                .collect::<Vec<_>>()
798                .join("\n");
799            Err(yoagent::types::ToolError::Failed(error_text))
800        } else {
801            Ok(tool_result)
802        }
803    }
804}
805
806pub trait Extension: Send + Sync {
807    fn name(&self) -> Cow<'static, str>;
808
809    /// Tools this extension provides (LLM-callable), each with its own prompt metadata.
810    fn tools(&self) -> Vec<ToolDefinition> {
811        vec![]
812    }
813
814    /// Slash commands this extension provides (e.g. `/quit`, `/model`).
815    /// Built-in commands and extension commands use the same interface.
816    fn commands(&self) -> Vec<SlashCommand> {
817        vec![]
818    }
819
820    /// Skills this extension provides (AgentSkills-compatible).
821    /// Merged into the session's skill set for /skill:name expansion and system prompt.
822    fn skills(&self) -> yoagent::skills::SkillSet {
823        yoagent::skills::SkillSet::empty()
824    }
825
826    /// Called before compaction runs (matching pi's `session_before_compact`).
827    /// Return `Some(BeforeCompactResult { cancel: true, .. })` to cancel compaction.
828    /// Return `Some(BeforeCompactResult { cancel: false, summary: Some(...), .. })`
829    /// to provide a custom summary instead of calling the provider.
830    /// Return `None` to let the default compaction proceed.
831    ///
832    /// `cancel` is a cancellation token — check `cancel.is_cancelled()` in
833    /// long-running hooks and return immediately if true (matching pi's
834    /// `AbortSignal` passed to `session_before_compact`).
835    fn before_compact(
836        &self,
837        _first_kept_entry_id: &str,
838        _tokens_before: u64,
839        _reason: &str,
840        _cancel: &Cancel,
841    ) -> Option<BeforeCompactResult> {
842        None
843    }
844
845    /// Called after compaction completes (matching pi's `session_compact`).
846    ///
847    /// `cancel` is a cancellation token — check `cancel.is_cancelled()` in
848    /// long-running hooks and return early if true (matching pi's
849    /// `AbortSignal` passed to `session_compact`).
850    #[allow(clippy::too_many_arguments)]
851    fn after_compact(
852        &self,
853        _summary: &str,
854        _first_kept_entry_id: &str,
855        _tokens_before: u64,
856        _estimated_tokens_after: u64,
857        _from_hook: bool,
858        _reason: &str,
859        _cancel: &Cancel,
860    ) {
861    }
862}
863
864// ── Tests ──────────────────────────────────────────────────────────
865
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{Value, json};

    /// Applies `coerce_primitive_by_type` to an owned value and returns it,
    /// so each coercion case reads as a single `assert_eq!`.
    fn coerced(ty: &str, mut value: Value) -> Value {
        coerce_primitive_by_type(ty, &mut value);
        value
    }

    /// Applies `coerce_with_json_schema` to owned arguments and returns them.
    fn coerced_by_schema(schema: &Value, mut args: Value) -> Value {
        coerce_with_json_schema(schema, &mut args);
        args
    }

    // ── coerce_primitive_by_type ────────────────────────────────────

    #[test]
    fn test_coerce_string_from_number() {
        assert_eq!(coerced("string", json!(42)), json!("42"));
    }

    #[test]
    fn test_coerce_string_from_boolean() {
        assert_eq!(coerced("string", json!(true)), json!("true"));
    }

    #[test]
    fn test_coerce_string_from_null() {
        assert_eq!(coerced("string", json!(null)), json!(""));
    }

    #[test]
    fn test_coerce_string_unchanged() {
        assert_eq!(coerced("string", json!("hello")), json!("hello"));
    }

    #[test]
    fn test_coerce_number_from_string() {
        assert_eq!(coerced("number", json!("42.5")), json!(42.5));
    }

    #[test]
    fn test_coerce_number_from_boolean() {
        assert_eq!(coerced("number", json!(true)), json!(1.0));
    }

    #[test]
    fn test_coerce_number_from_null() {
        assert_eq!(coerced("number", json!(null)), json!(0.0));
    }

    #[test]
    fn test_coerce_integer_from_string() {
        assert_eq!(coerced("integer", json!("7")), json!(7i64));
    }

    #[test]
    fn test_coerce_integer_from_float() {
        assert_eq!(coerced("integer", json!(3.9)), json!(3i64));
    }

    #[test]
    fn test_coerce_integer_from_boolean() {
        assert_eq!(coerced("integer", json!(false)), json!(0i64));
    }

    #[test]
    fn test_coerce_boolean_from_string_true() {
        assert_eq!(coerced("boolean", json!("true")), json!(true));
    }

    #[test]
    fn test_coerce_boolean_from_string_yes() {
        assert_eq!(coerced("boolean", json!("yes")), json!(true));
    }

    #[test]
    fn test_coerce_boolean_from_number() {
        assert_eq!(coerced("boolean", json!(1)), json!(true));
    }

    #[test]
    fn test_coerce_boolean_from_null() {
        assert_eq!(coerced("boolean", json!(null)), json!(false));
    }

    #[test]
    fn test_coerce_array_from_scalar() {
        assert_eq!(coerced("array", json!("single")), json!(["single"]));
    }

    #[test]
    fn test_coerce_array_from_null() {
        assert_eq!(coerced("array", json!(null)), json!([]));
    }

    #[test]
    fn test_coerce_array_unchanged() {
        assert_eq!(coerced("array", json!([1, 2, 3])), json!([1, 2, 3]));
    }

    #[test]
    fn test_coerce_unknown_type_does_nothing() {
        assert_eq!(coerced("widget", json!(42)), json!(42));
    }

    // ── coerce_with_json_schema ─────────────────────────────────────

    #[test]
    fn test_coerce_schema_string_from_number() {
        let schema = json!({
            "type": "object",
            "properties": { "name": {"type": "string"} }
        });
        assert_eq!(
            coerced_by_schema(&schema, json!({"name": 42})),
            json!({"name": "42"})
        );
    }

    #[test]
    fn test_coerce_schema_nested_object() {
        let schema = json!({
            "type": "object",
            "properties": {
                "metadata": {
                    "type": "object",
                    "properties": { "count": {"type": "integer"} }
                }
            }
        });
        assert_eq!(
            coerced_by_schema(&schema, json!({"metadata": {"count": "5"}})),
            json!({"metadata": {"count": 5i64}})
        );
    }

    #[test]
    fn test_coerce_schema_array_items() {
        let schema = json!({
            "type": "object",
            "properties": {
                "items": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": { "id": {"type": "integer"} }
                    }
                }
            }
        });
        let input = json!({"items": [{"id": "3"}, {"id": "7"}]});
        let expected = json!({"items": [{"id": 3i64}, {"id": 7i64}]});
        assert_eq!(coerced_by_schema(&schema, input), expected);
    }

    #[test]
    fn test_coerce_schema_non_object_skipped() {
        let schema = json!({"type": "string"});
        assert_eq!(coerced_by_schema(&schema, json!("hello")), json!("hello"));
    }

    // ── validate_tool_arguments ─────────────────────────────────────

    #[test]
    fn test_validate_valid_args() {
        let schema = json!({
            "type": "object",
            "properties": { "path": {"type": "string"} },
            "required": ["path"]
        });
        let args = json!({"path": "/tmp/foo.txt"});
        let outcome = validate_tool_arguments("test", &schema, &args);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validate_missing_required() {
        let schema = json!({
            "type": "object",
            "properties": { "path": {"type": "string"} },
            "required": ["path"]
        });
        let args = json!({});
        let message = validate_tool_arguments("test", &schema, &args).unwrap_err();
        assert!(message.contains("Required"));
        assert!(message.contains("test"));
    }

    #[test]
    fn test_validate_wrong_type() {
        let schema = json!({
            "type": "object",
            "properties": { "count": {"type": "integer"} }
        });
        let args = json!({"count": "not-a-number"});
        let message = validate_tool_arguments("test", &schema, &args).unwrap_err();
        assert!(message.contains("Expected integer"));
    }

    #[test]
    fn test_validate_additional_properties() {
        let schema = json!({
            "type": "object",
            "properties": { "name": {"type": "string"} },
            "additionalProperties": false
        });
        let args = json!({"name": "alice", "extra": "bad"});
        let message = validate_tool_arguments("test", &schema, &args).unwrap_err();
        assert!(message.contains("must NOT have additional properties"));
    }

    #[test]
    fn test_validate_not_an_object() {
        let schema = json!({
            "type": "object",
            "properties": {}
        });
        let args = json!("a string, not an object");
        let message = validate_tool_arguments("test", &schema, &args).unwrap_err();
        assert!(message.contains("Expected object"));
    }

    #[test]
    fn test_validate_array_item_types() {
        let schema = json!({
            "type": "object",
            "properties": {
                "tags": { "type": "array", "items": {"type": "string"} }
            }
        });
        let args = json!({"tags": [1, 2, 3]});
        let message = validate_tool_arguments("test", &schema, &args).unwrap_err();
        assert!(message.contains("Expected string"));
    }

    // ── Cancel ──────────────────────────────────────────────────────

    #[test]
    fn test_cancel_new_not_cancelled() {
        let token = Cancel::new();
        assert!(!token.is_cancelled());
        token.check().unwrap();
    }

    #[test]
    fn test_cancel_after_cancel() {
        let token = Cancel::new();
        token.cancel();
        assert!(token.is_cancelled());
        assert!(token.check().is_err());
    }

    #[test]
    fn test_cancel_default_not_cancelled() {
        assert!(!Cancel::default().is_cancelled());
    }

    #[test]
    fn test_cancel_is_send_sync() {
        // Compile-time check: instantiating this with `Cancel` fails to build
        // unless `Cancel` is both `Send` and `Sync`.
        fn require_send_sync<T: Send + Sync>() {}
        require_send_sync::<Cancel>();
    }
}