use crate::config::constants::tools;
use hashbrown::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
pub use vtcode_utility_tool_specs::{
AdditionalProperties, FreeformTool, FreeformToolFormat, JsonSchema, ResponsesApiTool,
};
/// Category of tool a [`ToolHandler`] serves; compared against a
/// [`ToolPayload`] variant by [`ToolHandler::matches_kind`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ToolKind {
    /// Function-style tool taking a JSON argument string.
    Function,
    /// Tool provided by an MCP server.
    Mcp,
    /// Free-form custom tool taking raw text input.
    Custom,
}
/// Raw input of a single tool call, shaped by the kind of tool invoked.
#[derive(Debug, Clone)]
pub enum ToolPayload {
    /// Function tool call; `arguments` is the unparsed argument string.
    Function { arguments: String },
    /// Custom (free-form) tool call carrying raw text input.
    Custom { input: String },
    /// MCP tool call; arguments may be absent.
    Mcp { arguments: Option<Value> },
    /// Local shell call with already-structured parameters.
    LocalShell { params: ShellToolCallParams },
}
/// Parameters of a local shell tool call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShellToolCallParams {
    /// Command line as separate argv elements.
    pub command: Vec<String>,
    /// Optional working directory for the command.
    pub workdir: Option<String>,
    /// Optional timeout in milliseconds.
    pub timeout_ms: Option<u64>,
    /// Optional sandbox permission request.
    pub sandbox_permissions: Option<SandboxPermissions>,
    /// Optional free-text justification accompanying the request.
    pub justification: Option<String>,
}
/// Sandbox permission level requested by a shell call.
///
/// Serialized in `snake_case` (e.g. `"use_default"`); defaults to
/// [`SandboxPermissions::UseDefault`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SandboxPermissions {
    /// Use the session's default sandbox settings.
    #[default]
    UseDefault,
    /// Request escalated (less restricted) execution.
    RequireEscalated,
    /// Request extra permissions on top of the default sandbox.
    WithAdditionalPermissions,
}
#[derive(Clone, Debug)]
pub enum ToolOutput {
Function {
content: String,
content_items: Option<Vec<ContentItem>>,
success: Option<bool>,
},
Mcp { result: McpToolResult },
}
impl ToolOutput {
pub fn simple(content: impl Into<String>) -> Self {
Self::Function {
content: content.into(),
content_items: None,
success: Some(true),
}
}
pub fn with_success(content: impl Into<String>, success: bool) -> Self {
Self::Function {
content: content.into(),
content_items: None,
success: Some(success),
}
}
pub fn error(message: impl Into<String>) -> Self {
Self::Function {
content: message.into(),
content_items: None,
success: Some(false),
}
}
pub fn content(&self) -> Option<&str> {
match self {
Self::Function { content, .. } => Some(content),
Self::Mcp { result } => result.content.first().and_then(|c| c.as_text()),
}
}
pub fn is_success(&self) -> bool {
match self {
Self::Function { success, .. } => success.unwrap_or(true),
Self::Mcp { result } => !result.is_error.unwrap_or(false),
}
}
}
/// One piece of structured tool output.
///
/// Serialized as an internally tagged object whose `"type"` field is the
/// `snake_case` variant name (`"text"`, `"image"`, `"resource"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentItem {
    /// Plain text.
    Text { text: String },
    /// Image payload with its MIME type (encoding of `data` not fixed here).
    Image { data: String, mime_type: String },
    /// Reference to a resource by URI, optionally with a MIME type.
    Resource {
        uri: String,
        mime_type: Option<String>,
    },
}
impl ContentItem {
    /// Borrows the text of a [`ContentItem::Text`]; every other variant yields `None`.
    pub fn as_text(&self) -> Option<&str> {
        match self {
            ContentItem::Text { text } => Some(text.as_str()),
            ContentItem::Image { .. } | ContentItem::Resource { .. } => None,
        }
    }
}
/// Result returned by an MCP tool call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpToolResult {
    /// Ordered content items produced by the tool.
    pub content: Vec<ContentItem>,
    /// Error flag; `None` is treated like `Some(false)` by `ToolOutput::is_success`.
    pub is_error: Option<bool>,
}
/// Everything a [`ToolHandler`] receives for one tool call.
pub struct ToolInvocation {
    /// Session the call belongs to.
    pub session: Arc<dyn ToolSession>,
    /// Context of the turn in which the call was made.
    pub turn: Arc<TurnContext>,
    /// Optional shared diff tracker for recording file changes.
    pub tracker: Option<SharedDiffTracker>,
    /// Identifier of this call.
    pub call_id: String,
    /// Name of the tool being invoked.
    pub tool_name: String,
    /// Raw call input.
    pub payload: ToolPayload,
}

/// [`DiffTracker`] shared behind an async-aware `tokio` mutex.
pub type SharedDiffTracker = Arc<tokio::sync::Mutex<DiffTracker>>;
/// Wrapper around a policy value.
///
/// The only constructor here is [`Constrained::allow_any`], so as written no
/// restriction is enforced on the stored value.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Constrained<T> {
    value: T,
}

impl<T> Constrained<T> {
    /// Wraps `initial_value` without any constraint.
    pub fn allow_any(initial_value: T) -> Self {
        Constrained { value: initial_value }
    }

    /// Borrows the wrapped value.
    pub fn get(&self) -> &T {
        let Constrained { value } = self;
        value
    }
}

impl<T: Copy> Constrained<T> {
    /// Returns a copy of the wrapped value.
    pub fn value(&self) -> T {
        *self.get()
    }
}

impl<T: Default> Default for Constrained<T> {
    /// Wraps `T::default()` via [`Constrained::allow_any`].
    fn default() -> Self {
        let initial: T = Default::default();
        Constrained::allow_any(initial)
    }
}
/// Session-level services available to tool handlers.
#[async_trait]
pub trait ToolSession: Send + Sync {
    /// Current working directory of the session.
    fn cwd(&self) -> &PathBuf;

    /// Root directory of the workspace.
    fn workspace_root(&self) -> &PathBuf;

    /// Records a warning message for the session.
    async fn record_warning(&self, message: String);

    /// Name or path of the user's shell (format decided by the implementor).
    fn user_shell(&self) -> &str;

    /// Delivers a tool lifecycle event.
    async fn send_event(&self, event: ToolEvent);
}
/// Per-turn settings shared by the tool calls of one turn.
#[derive(Debug, Clone)]
pub struct TurnContext {
    /// Directory that relative paths are resolved against.
    pub cwd: PathBuf,
    /// Identifier of the turn.
    pub turn_id: String,
    /// Optional sub-identifier.
    pub sub_id: Option<String>,
    /// Environment policy for spawned shells.
    pub shell_environment_policy: ShellEnvironmentPolicy,
    /// Approval policy in effect for this turn.
    pub approval_policy: Constrained<ApprovalPolicy>,
    /// Optional path to a Linux sandbox executable.
    pub codex_linux_sandbox_exe: Option<PathBuf>,
    /// Sandbox policy in effect for this turn.
    pub sandbox_policy: Constrained<super::sandboxing::SandboxPolicy>,
}

impl TurnContext {
    /// Owned-argument convenience wrapper around [`TurnContext::resolve_path_ref`].
    pub fn resolve_path(&self, path: Option<String>) -> PathBuf {
        let borrowed: Option<&str> = path.as_deref();
        self.resolve_path_ref(borrowed)
    }

    /// Resolves `path` against [`TurnContext::cwd`].
    ///
    /// `None` yields a copy of `cwd`. An absolute path is returned unchanged.
    /// A relative path is joined onto `cwd`. No normalization or existence
    /// check is performed.
    pub fn resolve_path_ref(&self, path: Option<&str>) -> PathBuf {
        let raw = match path {
            None => return self.cwd.clone(),
            Some(raw) => raw,
        };
        let candidate = PathBuf::from(raw);
        if !candidate.is_absolute() {
            return self.cwd.join(candidate);
        }
        candidate
    }
}
/// How the environment of spawned shells is prepared; defaults to `Inherit`.
#[derive(Debug, Clone, Default)]
pub enum ShellEnvironmentPolicy {
    /// Inherit the parent environment.
    #[default]
    Inherit,
    /// Start from an empty environment.
    Clean,
    /// Use the given variable map (how it is applied is decided by the consumer).
    Custom(HashMap<String, String>),
}
/// When user approval is requested for tool calls; defaults to `Never`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ApprovalPolicy {
    /// Never ask for approval.
    #[default]
    Never,
    /// Ask for approval on mutating calls.
    OnMutation,
    /// Always ask for approval.
    Always,
}
/// Accumulates file changes reported by patch applications.
#[derive(Debug, Default)]
pub struct DiffTracker {
    /// Recorded changes keyed by path. A later patch touching the same path
    /// replaces the earlier entry.
    pub changes: HashMap<PathBuf, FileChange>,
}

impl DiffTracker {
    /// Records every change of a patch that is about to be applied.
    pub fn on_patch_begin(&mut self, changes: &HashMap<PathBuf, FileChange>) {
        for (path, change) in changes {
            self.changes.insert(path.clone(), change.clone());
        }
    }

    /// Finalizes a patch; on failure, drops all recorded changes.
    pub fn on_patch_end(&mut self, success: bool) {
        if success {
            return;
        }
        // NOTE(review): this clears every recorded change, including changes
        // from earlier successful patches, not just the failed patch's
        // entries. Confirm this is intended.
        self.changes.clear();
    }
}
/// A single file modification.
///
/// Serialized as an internally tagged object whose `"type"` field is the
/// `snake_case` variant name.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum FileChange {
    /// File created with `content`.
    Add { content: String },
    /// File deleted.
    Delete,
    /// File contents replaced.
    Update {
        old_content: String,
        new_content: String,
    },
    /// File moved to `new_path`, optionally with content.
    Rename {
        new_path: PathBuf,
        content: Option<String>,
    },
}
/// Lifecycle event delivered via [`ToolSession::send_event`].
#[derive(Debug, Clone)]
pub enum ToolEvent {
    /// A tool call started.
    Begin(ToolEventBegin),
    /// A tool call finished successfully.
    Success(ToolEventSuccess),
    /// A tool call failed.
    Failure(ToolEventFailure),
    /// A patch application started.
    PatchApplyBegin(PatchApplyBeginEvent),
    /// A patch application finished.
    PatchApplyEnd(PatchApplyEndEvent),
}

/// Payload of [`ToolEvent::Begin`].
#[derive(Debug, Clone)]
pub struct ToolEventBegin {
    pub call_id: String,
    pub tool_name: String,
    pub turn_id: String,
}

/// Payload of [`ToolEvent::Success`].
#[derive(Debug, Clone)]
pub struct ToolEventSuccess {
    pub call_id: String,
    pub output: String,
}

/// Payload of [`ToolEvent::Failure`].
#[derive(Debug, Clone)]
pub struct ToolEventFailure {
    pub call_id: String,
    pub error: String,
}

/// Payload of [`ToolEvent::PatchApplyBegin`].
#[derive(Debug, Clone)]
pub struct PatchApplyBeginEvent {
    pub call_id: String,
    pub turn_id: String,
    /// Changes the patch is about to apply, keyed by path.
    pub changes: HashMap<PathBuf, FileChange>,
    /// Auto-approval flag as set by the emitter.
    pub auto_approved: bool,
}

/// Payload of [`ToolEvent::PatchApplyEnd`].
#[derive(Debug, Clone)]
pub struct PatchApplyEndEvent {
    pub call_id: String,
    pub success: bool,
    pub stdout: String,
    pub stderr: String,
}
/// Error returned by [`ToolHandler::handle`].
#[derive(Debug, thiserror::Error)]
pub enum ToolCallError {
    /// Message meant to be shown to the model.
    #[error("Tool error: {0}")]
    RespondToModel(String),
    /// Internal failure; `anyhow::Error` converts into this via `?`.
    #[error("Internal error: {0}")]
    Internal(#[from] anyhow::Error),
    /// The call was rejected.
    #[error("Tool rejected: {0}")]
    Rejected(String),
    /// The call exceeded its time budget, in milliseconds.
    #[error("Tool timed out after {0}ms")]
    Timeout(u64),
}

impl ToolCallError {
    /// Shorthand for [`ToolCallError::RespondToModel`].
    pub fn respond(message: impl Into<String>) -> Self {
        let text: String = message.into();
        ToolCallError::RespondToModel(text)
    }
}
/// Implementation of one tool.
#[async_trait]
pub trait ToolHandler: Send + Sync {
    /// Category of payload this handler serves.
    fn kind(&self) -> ToolKind;

    /// Returns whether `payload` is of the variant corresponding to [`Self::kind`].
    ///
    /// `ToolPayload::LocalShell` has no matching [`ToolKind`], so it never
    /// matches by default. Handlers for local shell calls must override this
    /// method.
    fn matches_kind(&self, payload: &ToolPayload) -> bool {
        let expected = match payload {
            ToolPayload::Function { .. } => Some(ToolKind::Function),
            ToolPayload::Mcp { .. } => Some(ToolKind::Mcp),
            ToolPayload::Custom { .. } => Some(ToolKind::Custom),
            ToolPayload::LocalShell { .. } => None,
        };
        expected == Some(self.kind())
    }

    /// Whether this invocation mutates state; defaults to `false`.
    async fn is_mutating(&self, _invocation: &ToolInvocation) -> bool {
        false
    }

    /// Executes the tool call.
    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, ToolCallError>;
}
/// Declaration of a tool as advertised to the model.
///
/// Serialized as an internally tagged object whose `"type"` field is the
/// `snake_case` variant name.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolSpec {
    /// JSON-schema function tool.
    Function(ResponsesApiTool),
    /// Free-form tool.
    Freeform(FreeformTool),
    /// Built-in web search tool.
    WebSearch {},
    /// Built-in local shell tool.
    LocalShell {},
}

impl ToolSpec {
    /// Name under which this tool is exposed.
    ///
    /// Built-in variants use fixed names: `tools::WEB_SEARCH` for web search
    /// and `"local_shell"` for the local shell.
    pub fn name(&self) -> &str {
        match self {
            Self::Function(function_tool) => &function_tool.name,
            Self::Freeform(freeform_tool) => &freeform_tool.name,
            Self::WebSearch {} => tools::WEB_SEARCH,
            Self::LocalShell {} => "local_shell",
        }
    }
}
/// A [`ToolSpec`] paired with its parallel-call capability.
#[derive(Debug, Clone)]
pub struct ConfiguredToolSpec {
    pub spec: ToolSpec,
    /// Whether this tool may be called in parallel with other tool calls.
    pub supports_parallel_tool_calls: bool,
}

impl ConfiguredToolSpec {
    /// Pairs `spec` with its `supports_parallel` flag.
    pub fn new(spec: ToolSpec, supports_parallel: bool) -> Self {
        let supports_parallel_tool_calls = supports_parallel;
        ConfiguredToolSpec {
            spec,
            supports_parallel_tool_calls,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `TurnContext` rooted at `cwd` with default policies.
    fn context_at(cwd: &str) -> TurnContext {
        TurnContext {
            cwd: PathBuf::from(cwd),
            turn_id: "test".to_string(),
            sub_id: None,
            shell_environment_policy: ShellEnvironmentPolicy::default(),
            approval_policy: Constrained::allow_any(ApprovalPolicy::default()),
            codex_linux_sandbox_exe: None,
            sandbox_policy: Constrained::allow_any(Default::default()),
        }
    }

    #[test]
    fn test_tool_output_simple() {
        let out = ToolOutput::simple("Hello, world!");
        assert!(out.is_success());
        assert_eq!(out.content(), Some("Hello, world!"));
    }

    #[test]
    fn test_tool_output_error() {
        let out = ToolOutput::error("Something went wrong");
        assert!(!out.is_success());
        assert_eq!(out.content(), Some("Something went wrong"));
    }

    #[test]
    fn test_sandbox_permissions_default() {
        assert_eq!(SandboxPermissions::default(), SandboxPermissions::UseDefault);
    }

    #[test]
    fn test_turn_context_resolve_path_absolute() {
        let resolved = context_at("/workspace").resolve_path(Some("/absolute/path".to_string()));
        assert_eq!(resolved, PathBuf::from("/absolute/path"));
    }

    #[test]
    fn test_turn_context_resolve_path_relative() {
        let resolved = context_at("/workspace").resolve_path(Some("relative/path".to_string()));
        assert_eq!(resolved, PathBuf::from("/workspace/relative/path"));
    }
}