// omni_dev/cli/ai/chat.rs

//! Interactive AI chat command.

use std::io::{self, Write};

use anyhow::Result;
use clap::Parser;
use crossterm::{
    event::{self, Event, KeyCode, KeyEventKind, KeyModifiers},
    terminal::{disable_raw_mode, enable_raw_mode},
};
/// Interactive AI chat session.
///
/// Invoked from the CLI; `execute` checks credentials, prints a banner, and
/// runs a read/respond loop until the user exits (Ctrl+D / Ctrl+C).
#[derive(Parser)]
pub struct ChatCommand {
    /// AI model to use (overrides environment configuration).
    #[arg(long)]
    pub model: Option<String>,
}
20impl ChatCommand {
21    /// Executes the chat command.
22    pub async fn execute(self) -> Result<()> {
23        let ai_info = crate::utils::preflight::check_ai_credentials(self.model.as_deref())?;
24        eprintln!(
25            "Connected to {} (model: {})",
26            ai_info.provider, ai_info.model
27        );
28        eprintln!("Enter to send, Shift+Enter for newline, Ctrl+D to exit.\n");
29
30        let client = crate::claude::create_default_claude_client(self.model, None)?;
31
32        chat_loop(&client).await
33    }
34}
36/// Sends a single user message to the configured AI and returns the response.
37///
38/// Shared between the MCP `ai_chat` tool and any non-interactive CLI callers.
39/// The function performs the same preflight credential check as the CLI chat
40/// loop — on missing credentials it returns the preflight error verbatim so
41/// MCP tool callers see the same diagnostic message the CLI would print.
42///
43/// `model` selects the AI model; `None` uses the environment default.
44/// `system_prompt` defaults to `"You are a helpful assistant."` (matching the
45/// CLI's default) when `None`.
46pub async fn run_chat(
47    message: &str,
48    model: Option<String>,
49    system_prompt: Option<String>,
50) -> Result<String> {
51    crate::utils::preflight::check_ai_credentials(model.as_deref())?;
52    let client = crate::claude::create_default_claude_client(model, None)?;
53    let system = system_prompt
54        .as_deref()
55        .unwrap_or("You are a helpful assistant.");
56    client.send_message(system, message).await
57}
59async fn chat_loop(client: &crate::claude::client::ClaudeClient) -> Result<()> {
60    let system_prompt = "You are a helpful assistant.";
61
62    loop {
63        let input = match read_user_input() {
64            Ok(Some(text)) => text,
65            Ok(None) => {
66                eprintln!("\nGoodbye!");
67                break;
68            }
69            Err(e) => {
70                eprintln!("\nInput error: {e}");
71                break;
72            }
73        };
74
75        let trimmed = input.trim();
76        if trimmed.is_empty() {
77            continue;
78        }
79
80        let response = client.send_message(system_prompt, trimmed).await?;
81        println!("{response}\n");
82    }
83
84    Ok(())
85}
/// Guard that disables raw mode on drop.
///
/// Constructed immediately after a successful `enable_raw_mode()` so the
/// terminal is restored on every exit path, including early returns and
/// error propagation out of the input loop.
struct RawModeGuard;

impl Drop for RawModeGuard {
    fn drop(&mut self) {
        // Best-effort restore: `Drop` cannot return an error and there is no
        // useful recovery if disabling raw mode fails here.
        let _ = disable_raw_mode();
    }
}
96/// Reads multiline user input with "> " prompt.
97///
98/// Returns `Ok(Some(text))` on Enter, `Ok(None)` on Ctrl+D/Ctrl+C.
99fn read_user_input() -> Result<Option<String>> {
100    eprint!("> ");
101    io::stderr().flush()?;
102
103    enable_raw_mode()?;
104    let _guard = RawModeGuard;
105
106    let mut buffer = String::new();
107
108    loop {
109        if let Event::Key(key_event) = event::read()? {
110            match key_event.code {
111                KeyCode::Enter => {
112                    if key_event.modifiers.contains(KeyModifiers::SHIFT) {
113                        buffer.push('\n');
114                        eprint!("\r\n... ");
115                        io::stderr().flush()?;
116                    } else {
117                        eprint!("\r\n");
118                        io::stderr().flush()?;
119                        return Ok(Some(buffer));
120                    }
121                }
122                KeyCode::Char('d') if key_event.modifiers.contains(KeyModifiers::CONTROL) => {
123                    if buffer.is_empty() {
124                        return Ok(None);
125                    }
126                    eprint!("\r\n");
127                    io::stderr().flush()?;
128                    return Ok(Some(buffer));
129                }
130                KeyCode::Char('c') if key_event.modifiers.contains(KeyModifiers::CONTROL) => {
131                    return Ok(None);
132                }
133                KeyCode::Char(c) => {
134                    buffer.push(c);
135                    eprint!("{c}");
136                    io::stderr().flush()?;
137                }
138                KeyCode::Backspace if buffer.pop().is_some() => {
139                    eprint!("\x08 \x08");
140                    io::stderr().flush()?;
141                }
142                _ => {}
143            }
144        }
145    }
146}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Env-isolation lock — tests in this module must serialise because they
    /// mutate `HOME` and every provider env var the preflight check reads.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Every provider/credential/model env var these tests snapshot, clear,
    /// and restore.
    const KEYS: &[&str] = &[
        "USE_OPENAI",
        "USE_OLLAMA",
        "CLAUDE_CODE_USE_BEDROCK",
        "CLAUDE_API_KEY",
        "ANTHROPIC_API_KEY",
        "ANTHROPIC_AUTH_TOKEN",
        "ANTHROPIC_BEDROCK_BASE_URL",
        "OPENAI_API_KEY",
        "OPENAI_AUTH_TOKEN",
        "OLLAMA_MODEL",
        "OLLAMA_BASE_URL",
        "ANTHROPIC_MODEL",
    ];

    /// Captures the current value of every key in `KEYS`, plus `HOME`.
    fn snapshot_env() -> Vec<(&'static str, Option<String>)> {
        let mut v: Vec<(&'static str, Option<String>)> =
            KEYS.iter().map(|k| (*k, std::env::var(k).ok())).collect();
        v.push(("HOME", std::env::var("HOME").ok()));
        v
    }

    /// Restores a snapshot taken by `snapshot_env`.
    fn restore_env(snap: Vec<(&'static str, Option<String>)>) {
        for (k, v) in snap {
            match v {
                Some(val) => std::env::set_var(k, val),
                None => std::env::remove_var(k),
            }
        }
    }

    /// Points `HOME` at an empty temp dir and clears every key in `KEYS`, so
    /// the preflight check cannot find credentials anywhere. The returned
    /// guard keeps the temp dir alive for the duration of the test.
    fn isolate_empty_home() -> tempfile::TempDir {
        let dir = {
            std::fs::create_dir_all("tmp").ok();
            tempfile::TempDir::new_in("tmp").unwrap()
        };
        std::env::set_var("HOME", dir.path());
        for k in KEYS {
            std::env::remove_var(k);
        }
        dir
    }

    /// Starts a wiremock server that answers `POST /v1/chat/completions`
    /// with an OpenAI-compatible completion whose assistant text is
    /// `content`.
    ///
    /// Extracted because the happy-path tests previously duplicated this
    /// fixture verbatim; the canned body now lives in exactly one place.
    async fn mock_ollama_server(content: &str) -> wiremock::MockServer {
        let server = wiremock::MockServer::start().await;
        wiremock::Mock::given(wiremock::matchers::method("POST"))
            .and(wiremock::matchers::path("/v1/chat/completions"))
            .respond_with(
                wiremock::ResponseTemplate::new(200).set_body_json(serde_json::json!({
                    "id": "test",
                    "object": "chat.completion",
                    "choices": [{
                        "index": 0,
                        "message": {"role": "assistant", "content": content},
                        "finish_reason": "stop"
                    }]
                })),
            )
            .mount(&server)
            .await;
        server
    }

    /// Routes the client at `server` via Ollama mode, which skips the
    /// credential preflight.
    fn point_env_at(server: &wiremock::MockServer) {
        std::env::set_var("USE_OLLAMA", "true");
        std::env::set_var("OLLAMA_MODEL", "llama2");
        std::env::set_var("OLLAMA_BASE_URL", server.uri());
    }

    #[allow(clippy::await_holding_lock)]
    #[tokio::test]
    async fn run_chat_returns_error_when_credentials_missing() {
        let _guard = ENV_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let snap = snapshot_env();
        let _home = isolate_empty_home();

        let err = run_chat("hello", None, None).await.unwrap_err();
        let msg = format!("{err}");
        assert!(
            msg.contains("API key not found") || msg.contains("not found"),
            "expected credential error, got: {msg}"
        );

        restore_env(snap);
    }

    #[allow(clippy::await_holding_lock)]
    #[tokio::test]
    async fn run_chat_bubbles_up_credential_error_with_custom_system_prompt() {
        let _guard = ENV_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let snap = snapshot_env();
        let _home = isolate_empty_home();

        // Custom system prompt should not bypass the preflight check.
        let err = run_chat("hello", None, Some("be terse".to_string()))
            .await
            .unwrap_err();
        assert!(format!("{err}").contains("not found"));

        restore_env(snap);
    }

    #[allow(clippy::await_holding_lock)]
    #[tokio::test]
    async fn run_chat_propagates_model_override_through_preflight() {
        let _guard = ENV_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let snap = snapshot_env();
        let _home = isolate_empty_home();

        // With explicit model override, the same credential check must still run.
        let err = run_chat("hello", Some("claude-sonnet-4-6".to_string()), None)
            .await
            .unwrap_err();
        assert!(format!("{err}").contains("not found"));

        restore_env(snap);
    }

    /// Exercises the post-preflight code path (client construction and
    /// `send_message`) without requiring real AI credentials. Routes through
    /// Ollama mode, which skips the credential check, and points the client
    /// at a wiremock server that returns a canned OpenAI-compatible response.
    #[allow(clippy::await_holding_lock)]
    #[tokio::test]
    async fn run_chat_happy_path_via_mocked_ollama_returns_response_text() {
        let _guard = ENV_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let snap = snapshot_env();
        let _home = isolate_empty_home();

        let server = mock_ollama_server("canned-response").await;
        point_env_at(&server);

        let out = run_chat("hello", None, Some("be terse".to_string()))
            .await
            .unwrap();
        assert_eq!(out, "canned-response");

        restore_env(snap);
    }

    /// As above but with `system_prompt = None`, exercising the default
    /// `"You are a helpful assistant."` branch.
    #[allow(clippy::await_holding_lock)]
    #[tokio::test]
    async fn run_chat_default_system_prompt_path_via_mocked_ollama() {
        let _guard = ENV_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let snap = snapshot_env();
        let _home = isolate_empty_home();

        let server = mock_ollama_server("ok").await;
        point_env_at(&server);

        let out = run_chat("hello", None, None).await.unwrap();
        assert_eq!(out, "ok");

        restore_env(snap);
    }
}
334}