Skip to main content

ai_agent/utils/
process_user_input.rs

1//! Process user input utilities - translates processUserInput.ts from TypeScript
2//!
3//! This module handles processing user input, including text prompts, bash commands,
4//! slash commands, and attachments.
5
6#![allow(dead_code)]
7
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10
11use crate::types::Message;
12
/// Mode in which a piece of user input is interpreted.
///
/// Variants are (de)serialized with lowercase names (`"prompt"`, `"bash"`,
/// `"print"`, `"continue"`). [`PromptInputMode::Prompt`] is the `Default`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum PromptInputMode {
    /// Regular prompt input; the default mode.
    #[default]
    Prompt,
    /// Input is treated as a bash command.
    Bash,
    /// Print mode.
    Print,
    /// Continue mode.
    Continue,
}
23
/// Context handed to user-input processing.
///
/// Per the original TypeScript, this combines `ToolUseContext` and
/// `LocalJSXCommandContext` into a single value.
#[derive(Debug, Clone)]
pub struct ProcessUserInputContext {
    /// Session ID
    pub session_id: String,
    /// Current working directory
    pub cwd: String,
    /// Agent ID, if one is set
    pub agent_id: Option<String>,
    /// Query tracking information, if any
    pub query_tracking: Option<QueryTracking>,
    /// Further context options (commands, tools, prompts, ...)
    pub options: ProcessUserInputContextOptions,
}
38
/// Query tracking for analytics.
///
/// Field names are (de)serialized in camelCase, so `chain_id` appears as
/// `chainId` on the wire.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct QueryTracking {
    /// Identifier of the query chain.
    pub chain_id: String,
    /// Depth within the chain.
    pub depth: u32,
}
46
/// Options carried inside [`ProcessUserInputContext`].
///
/// Several fields are untyped JSON (`serde_json::Value`); their schema is not
/// defined in this module.
#[derive(Debug, Clone)]
pub struct ProcessUserInputContextOptions {
    /// Available commands (untyped JSON values)
    pub commands: Vec<Value>,
    /// Debug mode flag
    pub debug: bool,
    /// Available tools
    pub tools: Vec<crate::types::ToolDefinition>,
    /// Verbose mode flag
    pub verbose: bool,
    /// Main loop model, if specified
    pub main_loop_model: Option<String>,
    /// Thinking configuration, if specified
    pub thinking_config: Option<crate::query_engine::ThinkingConfig>,
    /// MCP clients (untyped JSON values)
    pub mcp_clients: Vec<Value>,
    /// MCP resources, keyed by string (untyped JSON values)
    pub mcp_resources: std::collections::HashMap<String, Value>,
    /// IDE installation status, if known (untyped JSON)
    pub ide_installation_status: Option<Value>,
    /// Whether the session is non-interactive
    pub is_non_interactive_session: bool,
    /// Custom system prompt, if any
    pub custom_system_prompt: Option<String>,
    /// Text to append to the system prompt, if any
    pub append_system_prompt: Option<String>,
    /// Agent definitions
    pub agent_definitions: AgentDefinitions,
    /// Theme name, if any
    pub theme: Option<String>,
    /// Maximum budget in USD, if any
    pub max_budget_usd: Option<f64>,
}
81
82impl Default for ProcessUserInputContext {
83    fn default() -> Self {
84        Self {
85            session_id: String::new(),
86            cwd: String::new(),
87            agent_id: None,
88            query_tracking: None,
89            options: ProcessUserInputContextOptions::default(),
90        }
91    }
92}
93
94impl Default for ProcessUserInputContextOptions {
95    fn default() -> Self {
96        Self {
97            commands: vec![],
98            debug: false,
99            tools: vec![],
100            verbose: false,
101            main_loop_model: None,
102            thinking_config: None,
103            mcp_clients: vec![],
104            mcp_resources: std::collections::HashMap::new(),
105            ide_installation_status: None,
106            is_non_interactive_session: false,
107            custom_system_prompt: None,
108            append_system_prompt: None,
109            agent_definitions: AgentDefinitions::default(),
110            theme: None,
111            max_budget_usd: None,
112        }
113    }
114}
115
/// Agent definitions.
///
/// Field names are (de)serialized in camelCase (`activeAgents`, `allAgents`,
/// `allowedAgentTypes`). Agents are held as untyped JSON values.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AgentDefinitions {
    /// Active agents (untyped JSON values).
    pub active_agents: Vec<Value>,
    /// All agents (untyped JSON values).
    pub all_agents: Vec<Value>,
    /// Allowed agent types, if specified.
    pub allowed_agent_types: Option<Vec<String>>,
}
124
/// Effort value for the model.
///
/// Field names are (de)serialized in camelCase.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct EffortValue {
    /// Effort setting, as a free-form string.
    pub effort: String,
    /// Optional reason accompanying the effort setting.
    pub reason: Option<String>,
}
132
/// Result of processing user input.
///
/// The `Default` implementation yields no messages with `should_query` set to
/// `true` and every optional field unset.
#[derive(Debug, Clone)]
pub struct ProcessUserInputBaseResult {
    /// Messages to be sent to the model
    pub messages: Vec<Message>,
    /// Whether a query should be made
    pub should_query: bool,
    /// Allowed tools (optional)
    pub allowed_tools: Option<Vec<String>>,
    /// Model to use (optional)
    pub model: Option<String>,
    /// Effort value (optional)
    pub effort: Option<EffortValue>,
    /// Output text for non-interactive mode
    pub result_text: Option<String>,
    /// Next input to prefill (optional)
    pub next_input: Option<String>,
    /// Whether to submit the next input
    pub submit_next_input: Option<bool>,
}
153
154impl Default for ProcessUserInputBaseResult {
155    fn default() -> Self {
156        Self {
157            messages: vec![],
158            should_query: true,
159            allowed_tools: None,
160            model: None,
161            effort: None,
162            result_text: None,
163            next_input: None,
164            submit_next_input: None,
165        }
166    }
167}
168
/// Input for the `process_user_input` function.
///
/// As written in this file, `process_user_input` reads `input`, `mode`,
/// `context`, `pasted_contents`, `set_user_input_on_processing`, `uuid`,
/// `is_meta`, `skip_slash_commands` and `bridge_origin`; the remaining fields
/// are accepted but not consulted by it.
pub struct ProcessUserInputOptions {
    /// Input string or content blocks
    pub input: ProcessUserInput,
    /// Input before expansion (for ultraplan keyword detection)
    pub pre_expansion_input: Option<String>,
    /// Input mode
    pub mode: PromptInputMode,
    /// Context for processing
    pub context: ProcessUserInputContext,
    /// Pasted contents from the user, keyed by paste id
    pub pasted_contents: Option<std::collections::HashMap<u32, PastedContent>>,
    /// IDE selection
    pub ide_selection: Option<IdeSelection>,
    /// Existing messages
    pub messages: Option<Vec<Message>>,
    /// Callback invoked with the input text while processing; it is called
    /// only in prompt mode, when a text is available and `is_meta` is not
    /// `Some(true)`
    pub set_user_input_on_processing: Option<Box<dyn Fn(Option<String>) + Send + Sync>>,
    /// UUID for the prompt
    pub uuid: Option<String>,
    /// Whether input is already being processed
    pub is_already_processing: Option<bool>,
    /// Query source
    pub query_source: Option<QuerySource>,
    /// Function to check if tool can be used
    pub can_use_tool: Option<crate::utils::hooks::CanUseToolFnJson>,
    /// Skip slash command processing
    pub skip_slash_commands: Option<bool>,
    /// Bridge origin (for remote control)
    pub bridge_origin: Option<bool>,
    /// Whether this is a meta message (system-generated)
    pub is_meta: Option<bool>,
    /// Skip attachment processing
    pub skip_attachments: Option<bool>,
}
204
205impl Default for ProcessUserInputOptions {
206    fn default() -> Self {
207        Self {
208            input: ProcessUserInput::String(String::new()),
209            pre_expansion_input: None,
210            mode: PromptInputMode::Prompt,
211            context: ProcessUserInputContext::default(),
212            pasted_contents: None,
213            ide_selection: None,
214            messages: None,
215            set_user_input_on_processing: None,
216            uuid: None,
217            is_already_processing: None,
218            query_source: None,
219            can_use_tool: None,
220            skip_slash_commands: None,
221            bridge_origin: None,
222            is_meta: None,
223            skip_attachments: None,
224        }
225    }
226}
227
/// User input: either a plain string or a list of content blocks.
#[derive(Clone)]
pub enum ProcessUserInput {
    /// Plain text input.
    String(String),
    /// Structured input made of content blocks (text, images, tool data).
    ContentBlocks(Vec<ContentBlockParam>),
}
234
235impl std::fmt::Debug for ProcessUserInput {
236    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
237        match self {
238            ProcessUserInput::String(s) => f.debug_tuple("String").field(s).finish(),
239            ProcessUserInput::ContentBlocks(blocks) => {
240                f.debug_tuple("ContentBlocks").field(blocks).finish()
241            }
242        }
243    }
244}
245
/// Content block parameter (similar to Anthropic SDK's ContentBlockParam).
///
/// With the attributes below, serde uses its default externally tagged enum
/// representation and renames the *variant* names to camelCase (`text`,
/// `image`, `toolUse`, `toolResult`). The container-level `rename_all` does
/// not rename the fields inside the variants, so e.g. `tool_use_id` keeps its
/// snake_case name.
///
/// NOTE(review): the Anthropic API shape uses a `"type"` discriminator with
/// snake_case values (`tool_use`, `tool_result`) — confirm whether this wire
/// format is intentional before these values are sent to the API.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum ContentBlockParam {
    /// Text content block
    Text {
        /// Text content
        text: String,
    },
    /// Image content block
    Image {
        /// Image source
        source: ImageSource,
    },
    /// Tool use content block
    ToolUse {
        /// Tool use ID
        id: String,
        /// Tool name
        name: String,
        /// Tool input (untyped JSON)
        input: Value,
    },
    /// Tool result content block
    ToolResult {
        /// ID of the tool use this result belongs to
        tool_use_id: String,
        /// Tool result content (untyped JSON)
        content: Value,
        /// Whether this is an error; defaults to `None` when absent on input
        /// and is omitted from output when `None`
        #[serde(default, skip_serializing_if = "Option::is_none")]
        is_error: Option<bool>,
    },
}
280
/// Image source for content blocks.
///
/// On the wire, `source_type` is named `type` (explicit rename) and
/// `media_type` is named `mediaType` (camelCase container rename).
///
/// NOTE(review): the Anthropic API names this field `media_type` — confirm
/// that `mediaType` is what consumers of this JSON expect.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ImageSource {
    /// Source type (e.g. "base64")
    #[serde(rename = "type")]
    pub source_type: String,
    /// Media type (e.g., "image/png")
    pub media_type: String,
    /// Base64-encoded image data
    pub data: String,
}
293
/// Content pasted by the user.
#[derive(Debug, Clone)]
pub struct PastedContent {
    /// Unique ID
    pub id: u32,
    /// Content (base64-encoded)
    pub content: String,
    /// Media type; `process_pasted_images` falls back to "image/png" when
    /// this is `None`
    pub media_type: Option<String>,
    /// Source path (optional)
    pub source_path: Option<String>,
    /// Dimensions (optional)
    pub dimensions: Option<ImageDimensions>,
}
308
/// Image dimensions.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ImageDimensions {
    /// Image width (presumably pixels — unit is not stated in this file).
    pub width: u32,
    /// Image height (presumably pixels — unit is not stated in this file).
    pub height: u32,
}
316
/// IDE selection.
///
/// Field names are (de)serialized in camelCase (`filePath`, `selectedText`,
/// `cursorPosition`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct IdeSelection {
    /// File path
    pub file_path: String,
    /// Selected text, if any
    pub selected_text: Option<String>,
    /// Cursor position, if known
    pub cursor_position: Option<CursorPosition>,
}
328
/// Cursor position in the IDE.
///
/// NOTE(review): whether `line` and `character` are 0- or 1-based is not
/// established in this file — confirm against the producer.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CursorPosition {
    /// Line index of the cursor.
    pub line: u32,
    /// Character index of the cursor within the line.
    pub character: u32,
}
336
/// Source that triggered a query.
///
/// Variants are (de)serialized in snake_case (e.g. `slash_command`,
/// `auto_attach`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum QuerySource {
    /// Regular prompt.
    Prompt,
    /// Continuation.
    Continue,
    /// Slash command.
    SlashCommand,
    /// Bash command.
    BashCommand,
    /// Attachments.
    Attachments,
    /// Automatic attachment.
    AutoAttach,
    /// Resubmission.
    Resubmit,
}
349
350/// Process user input - main entry point
351///
352/// # Arguments
353/// * `options` - Options for processing user input
354///
355/// # Returns
356/// A future that resolves to ProcessUserInputBaseResult
357pub async fn process_user_input(
358    options: ProcessUserInputOptions,
359) -> Result<ProcessUserInputBaseResult, String> {
360    let input_string = match &options.input {
361        ProcessUserInput::String(s) => Some(s.clone()),
362        ProcessUserInput::ContentBlocks(blocks) => blocks.iter().find_map(|b| {
363            if let ContentBlockParam::Text { text } = b {
364                Some(text.clone())
365            } else {
366                None
367            }
368        }),
369    };
370
371    // Set user input on processing if in prompt mode
372    if options.mode == PromptInputMode::Prompt
373        && input_string.is_some()
374        && options.is_meta != Some(true)
375    {
376        if let Some(ref callback) = options.set_user_input_on_processing {
377            callback(input_string.clone());
378        }
379    }
380
381    // Process the input - take ownership of needed fields
382    let input = options.input;
383    let mode = options.mode;
384    let context = options.context;
385    let pasted_contents = options.pasted_contents;
386    let uuid = options.uuid;
387    let is_meta = options.is_meta;
388    let skip_slash_commands = options.skip_slash_commands;
389    let bridge_origin = options.bridge_origin;
390
391    let result = process_user_input_base(
392        input,
393        mode,
394        context,
395        pasted_contents,
396        uuid,
397        is_meta,
398        skip_slash_commands,
399        bridge_origin,
400    )
401    .await?;
402
403    // Execute user prompt submit hooks (simplified stub)
404    // In the full implementation, this would execute hooks and potentially modify result
405
406    Ok(result)
407}
408
409/// Internal function to process user input
410async fn process_user_input_base(
411    input: ProcessUserInput,
412    mode: PromptInputMode,
413    _context: ProcessUserInputContext,
414    pasted_contents: Option<std::collections::HashMap<u32, PastedContent>>,
415    uuid: Option<String>,
416    is_meta: Option<bool>,
417    skip_slash_commands: Option<bool>,
418    bridge_origin: Option<bool>,
419) -> Result<ProcessUserInputBaseResult, String> {
420    let input_string = match &input {
421        ProcessUserInput::String(s) => Some(s.clone()),
422        ProcessUserInput::ContentBlocks(blocks) => blocks.iter().find_map(|b| {
423            if let ContentBlockParam::Text { text } = b {
424                Some(text.clone())
425            } else {
426                None
427            }
428        }),
429    };
430
431    let mut preceding_input_blocks: Vec<ContentBlockParam> = vec![];
432    let mut normalized_input = input.clone();
433
434    // Handle content blocks - extract text and preceding blocks
435    if let ProcessUserInput::ContentBlocks(blocks) = &input {
436        if !blocks.is_empty() {
437            let last_block = blocks.last().unwrap();
438            if let ContentBlockParam::Text { text } = last_block {
439                let text = text.clone();
440                preceding_input_blocks = blocks[..blocks.len() - 1].to_vec();
441                normalized_input = ProcessUserInput::String(text);
442            } else {
443                preceding_input_blocks = blocks.clone();
444            }
445        }
446    }
447
448    // Validate mode requires string input
449    if input_string.is_none() && mode != PromptInputMode::Prompt {
450        return Err(format!("Mode: {:?} requires a string input.", mode));
451    }
452
453    // Process pasted images
454    let image_content_blocks = process_pasted_images(pasted_contents.as_ref()).await;
455
456    // Check for bridge-safe slash command override
457    let effective_skip_slash = check_bridge_safe_slash_command(
458        bridge_origin,
459        input_string.as_deref(),
460        skip_slash_commands,
461    );
462
463    // Handle bash commands
464    if let Some(input) = input_string {
465        if mode == PromptInputMode::Bash {
466            // Process bash command (simplified)
467            return process_bash_command(input, preceding_input_blocks, vec![]);
468        }
469
470        // Handle slash commands
471        if !effective_skip_slash && input.starts_with('/') {
472            return process_slash_command(
473                input,
474                preceding_input_blocks,
475                image_content_blocks,
476                vec![],
477            );
478        }
479    }
480
481    // Regular user prompt
482    process_text_prompt(
483        normalized_input,
484        image_content_blocks,
485        vec![],
486        uuid,
487        None, // permission_mode
488        is_meta,
489    )
490}
491
/// Decides whether slash-command handling should be skipped for this input.
///
/// Returns `false` (do not skip) when the input comes from the bridge
/// (`bridge_origin == Some(true)`) and the text starts with `/`. In every
/// other case the caller's `skip_slash_commands` flag is returned, with `None`
/// treated as `false`.
///
/// NOTE(review): no per-command lookup happens here, so every slash-prefixed
/// bridge input overrides `skip_slash_commands` — confirm that is intended.
fn check_bridge_safe_slash_command(
    bridge_origin: Option<bool>,
    input_string: Option<&str>,
    skip_slash_commands: Option<bool>,
) -> bool {
    let from_bridge = matches!(bridge_origin, Some(true));
    let looks_like_slash_command = input_string.map_or(false, |text| text.starts_with('/'));

    if from_bridge && looks_like_slash_command {
        // For bridge origin with slash command, we don't skip
        false
    } else {
        skip_slash_commands.unwrap_or(false)
    }
}
514
515/// Process pasted images
516async fn process_pasted_images(
517    pasted_contents: Option<&std::collections::HashMap<u32, PastedContent>>,
518) -> Vec<ContentBlockParam> {
519    if pasted_contents.is_none() {
520        return vec![];
521    }
522
523    let contents = pasted_contents.unwrap();
524    let mut image_blocks = vec![];
525
526    for (_, pasted) in contents.iter() {
527        let media_type = pasted.media_type.as_deref().unwrap_or("image/png");
528        image_blocks.push(ContentBlockParam::Image {
529            source: ImageSource {
530                source_type: "base64".to_string(),
531                media_type: media_type.to_string(),
532                data: pasted.content.clone(),
533            },
534        });
535    }
536
537    image_blocks
538}
539
540/// Process text prompt
541fn process_text_prompt(
542    input: ProcessUserInput,
543    _image_content_blocks: Vec<ContentBlockParam>,
544    _attachment_messages: Vec<Message>,
545    uuid: Option<String>,
546    _permission_mode: Option<crate::query_engine::PermissionMode>,
547    is_meta: Option<bool>,
548) -> Result<ProcessUserInputBaseResult, String> {
549    let content = match input {
550        ProcessUserInput::String(s) => {
551            if s.trim().is_empty() {
552                vec![]
553            } else {
554                vec![Value::String(s)]
555            }
556        }
557        ProcessUserInput::ContentBlocks(blocks) => blocks
558            .iter()
559            .map(|b| serde_json::to_value(b).unwrap_or(Value::Null))
560            .collect(),
561    };
562
563    let message = Message {
564        role: crate::types::MessageRole::User,
565        content: serde_json::json!({ "type": "text", "text": content }).to_string(),
566        attachments: None,
567        tool_call_id: None,
568        tool_calls: None,
569        is_error: None,
570    };
571
572    Ok(ProcessUserInputBaseResult {
573        messages: vec![message],
574        should_query: true,
575        ..Default::default()
576    })
577}
578
579/// Process bash command (simplified stub)
580fn process_bash_command(
581    _input: String,
582    _preceding_input_blocks: Vec<ContentBlockParam>,
583    _attachment_messages: Vec<Message>,
584) -> Result<ProcessUserInputBaseResult, String> {
585    // Simplified stub - full implementation would be in processBashCommand.tsx
586    Ok(ProcessUserInputBaseResult {
587        messages: vec![],
588        should_query: false,
589        allowed_tools: None,
590        model: None,
591        effort: None,
592        result_text: Some("Bash command processing not yet implemented".to_string()),
593        next_input: None,
594        submit_next_input: None,
595    })
596}
597
598/// Process slash command (simplified stub)
599fn process_slash_command(
600    _input: String,
601    _preceding_input_blocks: Vec<ContentBlockParam>,
602    _image_content_blocks: Vec<ContentBlockParam>,
603    _attachment_messages: Vec<Message>,
604) -> Result<ProcessUserInputBaseResult, String> {
605    // Simplified stub - full implementation would be in processSlashCommand.tsx
606    Ok(ProcessUserInputBaseResult {
607        messages: vec![],
608        should_query: false,
609        allowed_tools: None,
610        model: None,
611        effort: None,
612        result_text: Some("Slash command processing not yet implemented".to_string()),
613        next_input: None,
614        submit_next_input: None,
615    })
616}
617
#[cfg(test)]
mod tests {
    use super::*;

    /// Default options hold an empty string input in prompt mode.
    #[test]
    fn test_process_user_input_default() {
        let options = ProcessUserInputOptions::default();

        let input_is_empty_string = match &options.input {
            ProcessUserInput::String(text) => text.is_empty(),
            ProcessUserInput::ContentBlocks(_) => false,
        };

        assert!(input_is_empty_string);
        assert_eq!(options.mode, PromptInputMode::Prompt);
    }

    /// A non-blank string prompt yields one message and requests a query.
    #[test]
    fn test_process_text_prompt() {
        let input = ProcessUserInput::String("Hello".to_string());
        let uuid = Some("test-uuid".to_string());
        let is_meta = Some(true);

        let outcome = process_text_prompt(input, vec![], vec![], uuid, None, is_meta).unwrap();

        assert_eq!(outcome.messages.len(), 1);
        assert!(outcome.should_query);
    }
}
640
641        assert!(result.should_query);
642        assert_eq!(result.messages.len(), 1);
643    }
644}