#![allow(dead_code)]
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::types::Message;
/// How a piece of user input is to be interpreted.
///
/// Serialized in lowercase: `"prompt"`, `"bash"`, `"print"`, `"continue"`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum PromptInputMode {
/// Plain prompt input; the `Default` variant.
#[default]
Prompt,
/// Input routed to bash-command handling by `process_user_input_base`.
Bash,
// NOTE(review): `Print` and `Continue` get no dedicated routing in this file —
// confirm their intended semantics with the callers.
Print,
Continue,
}
/// Session-level context carried alongside a piece of user input.
///
/// NOTE(review): `process_user_input_base` currently receives this as `_context`
/// and does not read any of its fields.
#[derive(Debug, Clone)]
pub struct ProcessUserInputContext {
pub session_id: String,
pub cwd: String,
pub agent_id: Option<String>,
pub query_tracking: Option<QueryTracking>,
pub options: ProcessUserInputContextOptions,
}
/// Query-chain tracking information.
///
/// Serialized with camelCase keys: `chainId`, `depth`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct QueryTracking {
pub chain_id: String,
pub depth: u32,
}
/// Option bag stored in [`ProcessUserInputContext::options`].
///
/// Several fields hold untyped JSON (`serde_json::Value`) rather than dedicated
/// types. The `Default` impl yields empty collections, `false` flags and `None`
/// for every optional setting.
#[derive(Debug, Clone)]
pub struct ProcessUserInputContextOptions {
pub commands: Vec<Value>,
pub debug: bool,
pub tools: Vec<crate::types::ToolDefinition>,
pub verbose: bool,
pub main_loop_model: Option<String>,
pub thinking_config: Option<crate::query_engine::ThinkingConfig>,
pub mcp_clients: Vec<Value>,
// Keyed by an arbitrary string; the key meaning is not established in this file.
pub mcp_resources: std::collections::HashMap<String, Value>,
pub ide_installation_status: Option<Value>,
pub is_non_interactive_session: bool,
pub custom_system_prompt: Option<String>,
pub append_system_prompt: Option<String>,
pub agent_definitions: AgentDefinitions,
pub theme: Option<String>,
pub max_budget_usd: Option<f64>,
}
impl Default for ProcessUserInputContext {
    /// Builds an empty context: blank session id and `cwd`, no agent id, no query
    /// tracking, and default options.
    fn default() -> Self {
        let options = ProcessUserInputContextOptions::default();
        Self {
            session_id: String::default(),
            cwd: String::default(),
            agent_id: Option::default(),
            query_tracking: Option::default(),
            options,
        }
    }
}
impl Default for ProcessUserInputContextOptions {
    /// Returns options with every collection empty, every flag `false`, and every
    /// optional setting unset.
    fn default() -> Self {
        Self {
            // Collections.
            commands: Vec::new(),
            tools: Vec::new(),
            mcp_clients: Vec::new(),
            mcp_resources: Default::default(),
            agent_definitions: AgentDefinitions::default(),
            // Flags.
            debug: false,
            verbose: false,
            is_non_interactive_session: false,
            // Optional settings.
            main_loop_model: None,
            thinking_config: None,
            ide_installation_status: None,
            custom_system_prompt: None,
            append_system_prompt: None,
            theme: None,
            max_budget_usd: None,
        }
    }
}
/// Agent definitions available to the session, as untyped JSON values.
///
/// Serialized with camelCase keys: `activeAgents`, `allAgents`, `allowedAgentTypes`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AgentDefinitions {
pub active_agents: Vec<Value>,
pub all_agents: Vec<Value>,
pub allowed_agent_types: Option<Vec<String>>,
}
/// An effort setting with an optional reason, as carried by
/// [`ProcessUserInputBaseResult::effort`].
///
/// Serialized with keys `effort` and `reason`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct EffortValue {
pub effort: String,
pub reason: Option<String>,
}
/// Outcome of processing one piece of user input.
///
/// `Default` yields no messages, `should_query == true`, and `None` for every
/// optional field. The stub bash/slash handlers set `should_query` to `false` and
/// report via `result_text`.
#[derive(Debug, Clone)]
pub struct ProcessUserInputBaseResult {
pub messages: Vec<Message>,
// Presumably tells the caller whether to send `messages` to the model — confirm.
pub should_query: bool,
pub allowed_tools: Option<Vec<String>>,
pub model: Option<String>,
pub effort: Option<EffortValue>,
pub result_text: Option<String>,
pub next_input: Option<String>,
pub submit_next_input: Option<bool>,
}
impl Default for ProcessUserInputBaseResult {
fn default() -> Self {
Self {
messages: vec![],
should_query: true,
allowed_tools: None,
model: None,
effort: None,
result_text: None,
next_input: None,
submit_next_input: None,
}
}
}
/// Arguments to [`process_user_input`].
///
/// Carries no derives: `set_user_input_on_processing` is a boxed closure, which is
/// neither `Clone` nor `Debug`.
///
/// `process_user_input` reads `input`, `mode`, `context`, `pasted_contents`,
/// `set_user_input_on_processing`, `uuid`, `is_meta`, `skip_slash_commands` and
/// `bridge_origin`. The remaining fields are not read anywhere in this file.
pub struct ProcessUserInputOptions {
pub input: ProcessUserInput,
pub pre_expansion_input: Option<String>,
pub mode: PromptInputMode,
pub context: ProcessUserInputContext,
pub pasted_contents: Option<std::collections::HashMap<u32, PastedContent>>,
pub ide_selection: Option<IdeSelection>,
pub messages: Option<Vec<Message>>,
// Invoked with the user's text in `Prompt` mode for non-meta input.
pub set_user_input_on_processing: Option<Box<dyn Fn(Option<String>) + Send + Sync>>,
pub uuid: Option<String>,
pub is_already_processing: Option<bool>,
pub query_source: Option<QuerySource>,
pub can_use_tool: Option<crate::utils::hooks::CanUseToolFnJson>,
pub skip_slash_commands: Option<bool>,
pub bridge_origin: Option<bool>,
pub is_meta: Option<bool>,
pub skip_attachments: Option<bool>,
}
impl Default for ProcessUserInputOptions {
    /// An empty string prompt in `Prompt` mode with a default context; every
    /// optional setting is unset.
    fn default() -> Self {
        let input = ProcessUserInput::String(String::new());
        let context = ProcessUserInputContext::default();
        Self {
            input,
            mode: PromptInputMode::default(),
            context,
            pre_expansion_input: None,
            pasted_contents: None,
            ide_selection: None,
            messages: None,
            set_user_input_on_processing: None,
            uuid: None,
            is_already_processing: None,
            query_source: None,
            can_use_tool: None,
            skip_slash_commands: None,
            bridge_origin: None,
            is_meta: None,
            skip_attachments: None,
        }
    }
}
/// Raw user input: either a plain string or a list of content blocks.
///
/// `Debug` is implemented by hand rather than derived.
#[derive(Clone)]
pub enum ProcessUserInput {
String(String),
ContentBlocks(Vec<ContentBlockParam>),
}
impl std::fmt::Debug for ProcessUserInput {
    /// Formats as a one-field tuple named after the variant, e.g. `String("hi")`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (variant, payload) = match self {
            ProcessUserInput::String(text) => ("String", text as &dyn std::fmt::Debug),
            ProcessUserInput::ContentBlocks(blocks) => {
                ("ContentBlocks", blocks as &dyn std::fmt::Debug)
            }
        };
        f.debug_tuple(variant).field(payload).finish()
    }
}
/// One content block of user input.
///
/// Serde representation: externally tagged. On an enum, `rename_all = "camelCase"`
/// renames only the variant tags (`text`, `image`, `toolUse`, `toolResult`), not the
/// fields inside them. A text block therefore serializes as
/// `{"text": {"text": "..."}}`, and tool-result fields stay `tool_use_id` / `is_error`.
///
/// NOTE(review): if this is meant to match an API shape such as
/// `{"type": "text", "text": "..."}`, it would need an internally tagged
/// representation (`#[serde(tag = "type", ...)]`) — confirm the intended wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum ContentBlockParam {
Text {
text: String,
},
Image {
source: ImageSource,
},
ToolUse {
id: String,
name: String,
input: Value,
},
ToolResult {
tool_use_id: String,
content: Value,
// Omitted from the serialized form when `None`; missing on input defaults to `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
is_error: Option<bool>,
},
}
/// Source payload of an image content block.
///
/// Serialized keys: `type` (explicit rename), `mediaType`, `data`.
/// `process_pasted_images` fills it with `type = "base64"` and the pasted content
/// as `data`.
///
/// NOTE(review): camelCase produces `mediaType`; APIs that expect `media_type`
/// would not accept this shape — confirm the target format.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ImageSource {
#[serde(rename = "type")]
pub source_type: String,
pub media_type: String,
pub data: String,
}
/// An item pasted into the input, keyed by `id` in
/// [`ProcessUserInputOptions::pasted_contents`].
///
/// `process_pasted_images` uses `content` as base64 image data and `media_type`
/// (falling back to `"image/png"`) as its media type.
#[derive(Debug, Clone)]
pub struct PastedContent {
pub id: u32,
pub content: String,
pub media_type: Option<String>,
pub source_path: Option<String>,
pub dimensions: Option<ImageDimensions>,
}
/// Width and height of a pasted image; units are not established in this file.
///
/// Serialized with keys `width` and `height`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ImageDimensions {
pub width: u32,
pub height: u32,
}
/// The user's current IDE selection.
///
/// Serialized with camelCase keys: `filePath`, `selectedText`, `cursorPosition`.
/// Not read by the processing functions in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct IdeSelection {
pub file_path: String,
pub selected_text: Option<String>,
pub cursor_position: Option<CursorPosition>,
}
/// A cursor location within a file.
///
/// Serialized with keys `line` and `character`.
/// NOTE(review): whether these are zero- or one-based is not established here — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CursorPosition {
pub line: u32,
pub character: u32,
}
/// Where a query originated.
///
/// Serialized in snake_case: `"prompt"`, `"continue"`, `"slash_command"`,
/// `"bash_command"`, `"attachments"`, `"auto_attach"`, `"resubmit"`.
/// Carried in [`ProcessUserInputOptions::query_source`], which this file does not read.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum QuerySource {
Prompt,
Continue,
SlashCommand,
BashCommand,
Attachments,
AutoAttach,
Resubmit,
}
pub async fn process_user_input(
options: ProcessUserInputOptions,
) -> Result<ProcessUserInputBaseResult, String> {
let input_string = match &options.input {
ProcessUserInput::String(s) => Some(s.clone()),
ProcessUserInput::ContentBlocks(blocks) => blocks.iter().find_map(|b| {
if let ContentBlockParam::Text { text } = b {
Some(text.clone())
} else {
None
}
}),
};
if options.mode == PromptInputMode::Prompt
&& input_string.is_some()
&& options.is_meta != Some(true)
{
if let Some(ref callback) = options.set_user_input_on_processing {
callback(input_string.clone());
}
}
let input = options.input;
let mode = options.mode;
let context = options.context;
let pasted_contents = options.pasted_contents;
let uuid = options.uuid;
let is_meta = options.is_meta;
let skip_slash_commands = options.skip_slash_commands;
let bridge_origin = options.bridge_origin;
let result = process_user_input_base(
input,
mode,
context,
pasted_contents,
uuid,
is_meta,
skip_slash_commands,
bridge_origin,
)
.await?;
Ok(result)
}
async fn process_user_input_base(
input: ProcessUserInput,
mode: PromptInputMode,
_context: ProcessUserInputContext,
pasted_contents: Option<std::collections::HashMap<u32, PastedContent>>,
uuid: Option<String>,
is_meta: Option<bool>,
skip_slash_commands: Option<bool>,
bridge_origin: Option<bool>,
) -> Result<ProcessUserInputBaseResult, String> {
let input_string = match &input {
ProcessUserInput::String(s) => Some(s.clone()),
ProcessUserInput::ContentBlocks(blocks) => blocks.iter().find_map(|b| {
if let ContentBlockParam::Text { text } = b {
Some(text.clone())
} else {
None
}
}),
};
let mut preceding_input_blocks: Vec<ContentBlockParam> = vec![];
let mut normalized_input = input.clone();
if let ProcessUserInput::ContentBlocks(blocks) = &input {
if !blocks.is_empty() {
let last_block = blocks.last().unwrap();
if let ContentBlockParam::Text { text } = last_block {
let text = text.clone();
preceding_input_blocks = blocks[..blocks.len() - 1].to_vec();
normalized_input = ProcessUserInput::String(text);
} else {
preceding_input_blocks = blocks.clone();
}
}
}
if input_string.is_none() && mode != PromptInputMode::Prompt {
return Err(format!("Mode: {:?} requires a string input.", mode));
}
let image_content_blocks = process_pasted_images(pasted_contents.as_ref()).await;
let effective_skip_slash = check_bridge_safe_slash_command(
bridge_origin,
input_string.as_deref(),
skip_slash_commands,
);
if let Some(input) = input_string {
if mode == PromptInputMode::Bash {
return process_bash_command(input, preceding_input_blocks, vec![]);
}
if !effective_skip_slash && input.starts_with('/') {
return process_slash_command(
input,
preceding_input_blocks,
image_content_blocks,
vec![],
);
}
}
process_text_prompt(
normalized_input,
image_content_blocks,
vec![],
uuid,
None, is_meta,
)
}
/// Decides whether slash-command handling should be skipped for this input.
///
/// For bridge-originated input (`bridge_origin == Some(true)`) that starts with `/`,
/// the caller's `skip_slash_commands` request is overridden and `false` is
/// returned. In every other case the result is `skip_slash_commands`, defaulting to
/// `false` when unset.
///
/// NOTE(review): the override applies to any `/`-prefixed bridge input; no
/// per-command safety check is performed here — confirm that is intended.
fn check_bridge_safe_slash_command(
    bridge_origin: Option<bool>,
    input_string: Option<&str>,
    skip_slash_commands: Option<bool>,
) -> bool {
    let bridge_slash_input = bridge_origin == Some(true)
        && matches!(input_string, Some(text) if text.starts_with('/'));
    !bridge_slash_input && skip_slash_commands.unwrap_or(false)
}
/// Converts pasted contents into base64 image content blocks.
///
/// Every entry is treated as image data: `content` becomes the block's base64
/// `data`, and `media_type` defaults to `"image/png"` when unset. Returns an empty
/// vector when `pasted_contents` is `None`.
///
/// Entries are emitted in ascending id order. A `HashMap` iterates in arbitrary
/// order, so without sorting the block order could differ between runs (ids
/// presumably reflect paste order — confirm). Entries with empty `content` are
/// skipped because they carry no image data.
///
/// NOTE(review): entries are not filtered by kind — confirm that only image pastes
/// reach this function.
async fn process_pasted_images(
    pasted_contents: Option<&std::collections::HashMap<u32, PastedContent>>,
) -> Vec<ContentBlockParam> {
    let contents = match pasted_contents {
        Some(contents) => contents,
        None => return Vec::new(),
    };

    let mut entries: Vec<(&u32, &PastedContent)> = contents.iter().collect();
    // Keys are unique, so an unstable sort is deterministic here.
    entries.sort_unstable_by_key(|(id, _)| **id);

    entries
        .into_iter()
        .filter(|(_, pasted)| !pasted.content.is_empty())
        .map(|(_, pasted)| ContentBlockParam::Image {
            source: ImageSource {
                source_type: "base64".to_string(),
                media_type: pasted
                    .media_type
                    .clone()
                    .unwrap_or_else(|| "image/png".to_string()),
                data: pasted.content.clone(),
            },
        })
        .collect()
}
/// Builds the single user [`Message`] for a plain (non-bash, non-slash) prompt.
///
/// The message `content` is a JSON string:
/// - string input → one `{"type":"text","text":"..."}` object, where `text` is a
///   string; whitespace-only input yields an empty `text`;
/// - block input → a JSON array of the serialized blocks (a block that fails to
///   serialize becomes `null`).
///
/// The result always has `should_query == true`.
///
/// NOTE(review): `_image_content_blocks`, `_attachment_messages`, `uuid`,
/// `_permission_mode` and `is_meta` are accepted for interface parity but are not
/// yet attached to the message, so pasted images are currently dropped here.
///
/// # Errors
/// Currently never returns `Err`.
fn process_text_prompt(
    input: ProcessUserInput,
    _image_content_blocks: Vec<ContentBlockParam>,
    _attachment_messages: Vec<Message>,
    uuid: Option<String>,
    _permission_mode: Option<crate::query_engine::PermissionMode>,
    is_meta: Option<bool>,
) -> Result<ProcessUserInputBaseResult, String> {
    // Not yet threaded into `Message`; bound explicitly to avoid unused warnings.
    let _ = (uuid, is_meta);

    let content = match input {
        // Previously the text was wrapped in an array, producing a "text" block whose
        // `text` field was a JSON array; emit it as a string instead.
        ProcessUserInput::String(text) => {
            let text = if text.trim().is_empty() {
                String::new()
            } else {
                text
            };
            serde_json::json!({ "type": "text", "text": text })
        }
        // Several blocks cannot be represented as one text block; keep them as a list.
        ProcessUserInput::ContentBlocks(blocks) => Value::Array(
            blocks
                .iter()
                .map(|b| serde_json::to_value(b).unwrap_or(Value::Null))
                .collect(),
        ),
    };

    let message = Message {
        role: crate::types::MessageRole::User,
        content: content.to_string(),
        attachments: None,
        tool_call_id: None,
        tool_calls: None,
        is_error: None,
    };
    Ok(ProcessUserInputBaseResult {
        messages: vec![message],
        should_query: true,
        ..Default::default()
    })
}
fn process_bash_command(
_input: String,
_preceding_input_blocks: Vec<ContentBlockParam>,
_attachment_messages: Vec<Message>,
) -> Result<ProcessUserInputBaseResult, String> {
Ok(ProcessUserInputBaseResult {
messages: vec![],
should_query: false,
allowed_tools: None,
model: None,
effort: None,
result_text: Some("Bash command processing not yet implemented".to_string()),
next_input: None,
submit_next_input: None,
})
}
fn process_slash_command(
_input: String,
_preceding_input_blocks: Vec<ContentBlockParam>,
_image_content_blocks: Vec<ContentBlockParam>,
_attachment_messages: Vec<Message>,
) -> Result<ProcessUserInputBaseResult, String> {
Ok(ProcessUserInputBaseResult {
messages: vec![],
should_query: false,
allowed_tools: None,
model: None,
effort: None,
result_text: Some("Slash command processing not yet implemented".to_string()),
next_input: None,
submit_next_input: None,
})
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Default options carry an empty string prompt in `Prompt` mode.
    #[test]
    fn test_process_user_input_default() {
        let ProcessUserInputOptions { input, mode, .. } = ProcessUserInputOptions::default();
        match input {
            ProcessUserInput::String(text) => assert!(text.is_empty()),
            other => panic!("expected empty string input, got {:?}", other),
        }
        assert_eq!(mode, PromptInputMode::Prompt);
    }

    /// A plain text prompt yields exactly one message and requests a query.
    #[test]
    fn test_process_text_prompt() {
        let outcome = process_text_prompt(
            ProcessUserInput::String(String::from("Hello")),
            Vec::new(),
            Vec::new(),
            Some(String::from("test-uuid")),
            None,
            Some(true),
        )
        .expect("plain text prompt should succeed");
        assert!(outcome.should_query);
        assert_eq!(outcome.messages.len(), 1);
    }
}