#![allow(dead_code)]
use crate::error::AgentError;
use crate::services::model_cost::TokenUsage;
use crate::types::*;
use serde_json::Value;
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, Mutex};
use uuid::Uuid;
pub type SDKMessage = Value;
/// Construction options for [`QueryEngine`].
///
/// `Default` yields an empty `cwd`, empty collections, `None` for every optional
/// setting and `false` for every flag. It is derived because every field type
/// already implements `Default` with exactly those values.
#[derive(Default)]
pub struct QueryEngineConfig {
    pub cwd: String,
    pub tools: Vec<ToolDefinition>,
    pub commands: Vec<Value>,
    pub mcp_clients: Vec<crate::mcp::McpConnection>,
    pub agents: Vec<Value>,
    pub initial_messages: Option<Vec<Message>>,
    pub read_file_cache: Option<FileStateCache>,
    pub custom_system_prompt: Option<String>,
    pub append_system_prompt: Option<String>,
    pub user_specified_model: Option<String>,
    pub fallback_model: Option<String>,
    pub thinking_config: Option<ThinkingConfig>,
    pub max_turns: Option<u32>,
    pub max_budget_usd: Option<f64>,
    pub task_budget: Option<TaskBudget>,
    pub json_schema: Option<Value>,
    pub verbose: bool,
    pub replay_user_messages: bool,
    pub include_partial_messages: bool,
    pub abort_controller: Option<AbortController>,
    pub orphaned_permission: Option<OrphanedPermission>,
}
/// Budget attached to a task.
/// NOTE(review): the unit of `total` is not visible here — confirm against callers.
#[derive(Clone, Debug)]
pub struct TaskBudget {
    pub total: f64,
}

/// Thinking settings; wraps a single [`ThinkingType`] selector.
#[derive(Clone, Debug)]
pub struct ThinkingConfig {
    pub thinking_type: ThinkingType,
}

/// Thinking mode selector.
#[derive(Clone, Debug)]
pub enum ThinkingType {
    Adaptive,
    Enabled,
    Disabled,
}

impl Default for ThinkingConfig {
    /// Produces a config whose mode is [`ThinkingType::Adaptive`].
    fn default() -> ThinkingConfig {
        let thinking_type = ThinkingType::Adaptive;
        ThinkingConfig { thinking_type }
    }
}
/// Permission mode stored in `ToolPermissionContext::mode`; `Ask` when unspecified.
#[derive(Clone, Debug, Default, PartialEq)]
pub enum PermissionMode {
    /// The `Default` variant.
    #[default]
    Ask,
    Allow,
    Deny,
    Bypass,
}
/// Record of a tool use that was refused: which tool, which use id, and its input.
#[derive(Clone, Debug)]
pub struct SDKPermissionDenial {
    pub tool_name: String,
    pub tool_use_id: String,
    pub tool_input: Value,
}

/// Status update delivered through a `SetSdkStatusFn` callback.
#[derive(Clone, Debug)]
pub struct SDKStatus {
    pub status: String,
    pub message: Option<String>,
}
/// Cloneable cancellation flag.
///
/// The flag lives behind an `Arc`, so every clone observes the same state:
/// calling [`AbortController::abort`] on any clone is visible through all of them.
/// The derived `Default` starts in the not-aborted state (`Mutex::new(false)`).
#[derive(Clone, Debug, Default)]
pub struct AbortController {
    aborted: Arc<Mutex<bool>>,
}

impl AbortController {
    /// Creates a controller that has not been aborted.
    pub fn new() -> Self {
        Self::default()
    }

    /// Marks this controller, and every clone of it, as aborted. Calling it again is harmless.
    ///
    /// # Panics
    ///
    /// Panics if the inner mutex is poisoned.
    pub fn abort(&self) {
        let mut flag = self.aborted.lock().unwrap();
        *flag = true;
    }

    /// Reports whether `abort` has been called on this controller or any clone.
    ///
    /// # Panics
    ///
    /// Panics if the inner mutex is poisoned.
    pub fn is_aborted(&self) -> bool {
        let flag = self.aborted.lock().unwrap();
        *flag
    }
}
/// String-keyed map of JSON file-state entries.
/// NOTE(review): keys presumably are file paths — confirm against writers of `cache`.
#[derive(Clone, Debug, Default)]
pub struct FileStateCache {
    pub cache: HashMap<String, Value>,
}

/// Tool permission carried in `QueryEngineConfig::orphaned_permission` /
/// `AskConfig::orphaned_permission`.
#[derive(Clone, Debug)]
pub struct OrphanedPermission {
    pub tool_name: String,
    pub tool_input: Value,
    pub tool_use_id: String,
}
/// Input handed to a `HandleElicitationFn`.
#[derive(Clone, Debug)]
pub struct ElicitationRequest {
    pub tool_name: String,
    pub message: String,
    pub url: Option<String>,
}

/// Reply a `HandleElicitationFn` may produce (its future resolves to `Option<Self>`).
#[derive(Clone, Debug)]
pub struct ElicitationResponse {
    pub url: Option<String>,
    pub selection: Option<String>,
}
/// Value returned (inside `Option`) by a `SnipReplayFn`: a message list plus an
/// `executed` flag.
#[derive(Clone, Debug)]
pub struct SnipResult {
    pub messages: Vec<SDKMessage>,
    pub executed: bool,
}
/// Context describing a tool invocation; a reference to it is passed to a
/// `CanUseToolFn` permission check.
pub struct ToolUseContext {
    pub cwd: String,
    pub session_id: String,
    pub agent_id: Option<String>,
    pub query_tracking: Option<QueryTracking>,
    /// Tool-facing options (commands, tools, model, prompts, agents, ...).
    pub options: ToolUseContextOptions,
    pub abort_controller: AbortController,
    pub read_file_state: FileStateCache,
}
/// Options exposed to tools through `ToolUseContext::options`.
///
/// `Default` yields empty collections, `None` for every optional setting, `false`
/// for every flag and `AgentDefinitions::default()`. It is derived because every
/// field type already implements `Default` with exactly those values.
#[derive(Default)]
pub struct ToolUseContextOptions {
    pub commands: Vec<Value>,
    pub debug: bool,
    pub tools: Vec<ToolDefinition>,
    pub verbose: bool,
    pub main_loop_model: Option<String>,
    pub thinking_config: Option<ThinkingConfig>,
    pub mcp_clients: Vec<crate::mcp::McpConnection>,
    pub mcp_resources: HashMap<String, Value>,
    pub ide_installation_status: Option<Value>,
    pub is_non_interactive_session: bool,
    pub custom_system_prompt: Option<String>,
    pub append_system_prompt: Option<String>,
    pub agent_definitions: AgentDefinitions,
    pub theme: Option<String>,
    pub max_budget_usd: Option<f64>,
}
/// Agent lists carried in `ToolUseContextOptions::agent_definitions`.
#[derive(Clone, Debug, Default)]
pub struct AgentDefinitions {
    pub active_agents: Vec<Value>,
    pub all_agents: Vec<Value>,
    /// NOTE(review): `None` presumably means "no restriction" — confirm.
    pub allowed_agent_types: Option<Vec<String>>,
}
/// Identifies a query by the chain it belongs to (`chain_id`) and its `depth`.
#[derive(Clone, Debug)]
pub struct QueryTracking {
    pub chain_id: String,
    pub depth: u32,
}
/// Async permission check whose future resolves to a [`PermissionDecision`].
///
/// Parameters, in order: the tool definition, its JSON input, the
/// [`ToolUseContext`], an optional message, a string and an optional flag.
/// NOTE(review): the meaning of the `&str` and `Option<bool>` parameters is not
/// visible here — confirm against call sites.
pub type CanUseToolFn = dyn Fn(
        &ToolDefinition,
        &Value,
        &ToolUseContext,
        &Option<Message>,
        &str,
        Option<bool>,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = PermissionDecision> + Send + Sync>>
    + Send
    + Sync;

/// Boxed async handler that receives an [`ElicitationRequest`] and may answer
/// with an [`ElicitationResponse`].
pub type HandleElicitationFn = Box<
    dyn Fn(ElicitationRequest) -> std::pin::Pin<
            Box<dyn std::future::Future<Output = Option<ElicitationResponse>> + Send + Sync>,
        > + Send
        + Sync,
>;

/// Boxed callback that receives [`SDKStatus`] updates.
pub type SetSdkStatusFn = Box<dyn Fn(SDKStatus) + Send + Sync>;

/// Boxed callback given a message and a message slice; may return a [`SnipResult`].
pub type SnipReplayFn = Box<dyn Fn(&Message, &[Message]) -> Option<SnipResult> + Send + Sync>;
/// Value a `CanUseToolFn` future resolves to.
#[derive(Clone, Debug)]
pub enum PermissionDecision {
    Allow,
    /// Carries an optional human-readable `reason`.
    Deny { reason: Option<String> },
    /// NOTE(review): the unit/epoch of `expires_at` is not visible here — confirm.
    Ask { expires_at: Option<u64> },
}
/// Application-level state snapshot.
/// NOTE(review): `file_history` and `attribution` are untyped JSON here; their
/// schema is defined elsewhere — confirm before relying on it.
#[derive(Clone, Debug, Default)]
pub struct AppState {
    pub tool_permission_context: ToolPermissionContext,
    pub fast_mode: bool,
    pub file_history: Value,
    pub attribution: Value,
    pub mcp: McpState,
    pub effort_value: Option<f64>,
    pub advisor_model: Option<String>,
}

/// Permission settings: the active [`PermissionMode`], always-allow rules, and
/// additional working directories.
#[derive(Clone, Debug, Default)]
pub struct ToolPermissionContext {
    pub mode: PermissionMode,
    pub always_allow_rules: AlwaysAllowRules,
    /// NOTE(review): key/value meaning of this map is not visible here — confirm.
    pub additional_working_directories: HashMap<String, String>,
}

/// Always-allow rules; currently only a list of command rules.
#[derive(Clone, Debug, Default)]
pub struct AlwaysAllowRules {
    pub command: Vec<String>,
}

/// MCP-related state: JSON tool descriptors and known clients.
#[derive(Clone, Debug, Default)]
pub struct McpState {
    pub tools: Vec<Value>,
    pub clients: Vec<McpClient>,
}

/// A named MCP client and its type string.
#[derive(Clone, Debug)]
pub struct McpClient {
    pub name: String,
    pub client_type: String,
}
/// Token usage counters: input/output counts are always present, while the
/// prompt-cache counters are optional.
#[derive(Clone, Debug, Default)]
pub struct NonNullableUsage {
    pub input_tokens: u64,
    pub output_tokens: u64,
    pub cache_creation_input_tokens: Option<u64>,
    pub cache_read_input_tokens: Option<u64>,
}
impl From<TokenUsage> for NonNullableUsage {
fn from(usage: TokenUsage) -> Self {
Self {
input_tokens: usage.input_tokens as u64,
output_tokens: usage.output_tokens as u64,
cache_creation_input_tokens: Some(usage.prompt_cache_write_tokens as u64),
cache_read_input_tokens: Some(usage.prompt_cache_read_tokens as u64),
}
}
}
/// Per-submission options: an optional message `uuid` and `is_meta` flag.
#[derive(Default, Debug)]
pub struct SubmitOptions {
    pub uuid: Option<String>,
    pub is_meta: Option<bool>,
}
pub async fn ask(config: AskConfig) -> Result<Vec<SDKMessage>, AgentError> {
let initial_messages: Option<Vec<Message>> = config.mutable_messages.map(|msgs| {
msgs.into_iter()
.map(|v| Message {
role: MessageRole::User,
content: v.to_string(),
attachments: None,
tool_call_id: None,
tool_calls: None,
is_error: None,
})
.collect()
});
let engine = QueryEngine::new(QueryEngineConfig {
cwd: config.cwd,
tools: config.tools,
commands: vec![],
mcp_clients: config.mcp_clients.unwrap_or_default(),
agents: config.agents.unwrap_or_default(),
initial_messages,
read_file_cache: None,
custom_system_prompt: config.custom_system_prompt,
append_system_prompt: config.append_system_prompt,
user_specified_model: config.user_specified_model,
fallback_model: config.fallback_model,
thinking_config: config.thinking_config,
max_turns: config.max_turns,
max_budget_usd: config.max_budget_usd,
task_budget: config.task_budget,
json_schema: config.json_schema,
verbose: config.verbose.unwrap_or(false),
replay_user_messages: config.replay_user_messages.unwrap_or(false),
include_partial_messages: config.include_partial_messages.unwrap_or(false),
abort_controller: config.abort_controller,
orphaned_permission: config.orphaned_permission,
});
let messages: Vec<SDKMessage> = engine
.get_messages()
.iter()
.map(|m| {
serde_json::json!({
"role": format!("{:?}", m.role),
"content": m.content,
})
})
.collect();
Ok(messages)
}
/// Options for [`ask`].
///
/// `Default` yields an empty `prompt` and `cwd`, no tools, and `None` for every
/// optional setting. It is derived because every field type already implements
/// `Default` with exactly those values.
#[derive(Default)]
pub struct AskConfig {
    pub prompt: String,
    pub prompt_uuid: Option<String>,
    pub is_meta: Option<bool>,
    pub cwd: String,
    pub tools: Vec<ToolDefinition>,
    pub mcp_clients: Option<Vec<crate::mcp::McpConnection>>,
    pub verbose: Option<bool>,
    pub thinking_config: Option<ThinkingConfig>,
    pub max_turns: Option<u32>,
    pub max_budget_usd: Option<f64>,
    pub task_budget: Option<TaskBudget>,
    pub mutable_messages: Option<Vec<SDKMessage>>,
    pub custom_system_prompt: Option<String>,
    pub append_system_prompt: Option<String>,
    pub user_specified_model: Option<String>,
    pub fallback_model: Option<String>,
    pub json_schema: Option<Value>,
    pub abort_controller: Option<AbortController>,
    pub replay_user_messages: Option<bool>,
    pub include_partial_messages: Option<bool>,
    pub agents: Option<Vec<Value>>,
    pub orphaned_permission: Option<OrphanedPermission>,
}
/// Conversation engine.
///
/// Owns the [`QueryEngineConfig`], the message transcript, usage/permission
/// bookkeeping, the file-state cache and an abort signal.
pub struct QueryEngine {
    config: QueryEngineConfig,
    /// Transcript, seeded from `QueryEngineConfig::initial_messages`.
    mutable_messages: Vec<Message>,
    /// Shares state with the caller's controller when one was supplied in the config.
    abort_controller: AbortController,
    permission_denials: Vec<SDKPermissionDenial>,
    total_usage: NonNullableUsage,
    has_handled_orphaned_permission: bool,
    read_file_state: FileStateCache,
    discovered_skill_names: HashSet<String>,
    loaded_nested_memory_paths: HashSet<String>,
    /// Generated once in [`QueryEngine::new`] so every `get_session_id` call agrees.
    session_id: String,
}

impl QueryEngine {
    /// Creates an engine from `config`.
    ///
    /// - The transcript starts from `config.initial_messages`.
    /// - The file-state cache starts from `config.read_file_cache`.
    /// - The abort controller is `config.abort_controller`.
    ///
    /// Missing values fall back to empty or fresh defaults. `initial_messages`
    /// and `read_file_cache` are moved out of the stored config, which is left
    /// holding `None`. This avoids copying them and avoids requiring `Message: Clone`.
    pub fn new(mut config: QueryEngineConfig) -> Self {
        let mutable_messages = config.initial_messages.take().unwrap_or_default();
        let read_file_state = config.read_file_cache.take().unwrap_or_default();
        // Cloning an `AbortController` shares its `Arc`'d flag, so an `abort()` by
        // the caller is observed here (and `interrupt()` is visible to the caller).
        let abort_controller = config.abort_controller.clone().unwrap_or_default();
        Self {
            config,
            mutable_messages,
            abort_controller,
            permission_denials: Vec::new(),
            total_usage: NonNullableUsage::default(),
            has_handled_orphaned_permission: false,
            read_file_state,
            discovered_skill_names: HashSet::new(),
            loaded_nested_memory_paths: HashSet::new(),
            session_id: Uuid::new_v4().to_string(),
        }
    }

    /// Aborts the engine's controller. Only shared access is needed because the
    /// flag has interior mutability.
    pub fn interrupt(&self) {
        self.abort_controller.abort();
    }

    /// Returns the current transcript.
    pub fn get_messages(&self) -> &Vec<Message> {
        &self.mutable_messages
    }

    /// Returns the engine's file-state cache.
    pub fn get_read_file_state(&self) -> &FileStateCache {
        &self.read_file_state
    }

    /// Returns this engine's session id. It is stable for the engine's lifetime.
    pub fn get_session_id(&self) -> String {
        self.session_id.clone()
    }

    /// Overrides the model recorded in the config's `user_specified_model`.
    pub fn set_model(&mut self, model: String) {
        self.config.user_specified_model = Some(model);
    }
}