1use std::io::{self, Write};
4
5use anyhow::Result;
6use clap::Parser;
7use crossterm::{
8 event::{self, Event, KeyCode, KeyModifiers},
9 terminal::{disable_raw_mode, enable_raw_mode},
10};
11
// Command-line arguments for the interactive chat subcommand.
//
// Plain `//` comments are used deliberately: clap's derive turns `///` doc comments
// into help text, so adding them here would change the command's `--help` output.
#[derive(Parser)]
pub struct ChatCommand {
    // Optional model override, forwarded both to the credential preflight and to the
    // AI client constructor (see `ChatCommand::execute`).
    // NOTE(review): presumably `None` selects the provider's default model — confirm.
    #[arg(long)]
    pub model: Option<String>,
}
19
20impl ChatCommand {
21 pub async fn execute(self) -> Result<()> {
23 let ai_info = crate::utils::preflight::check_ai_credentials(self.model.as_deref())?;
24 eprintln!(
25 "Connected to {} (model: {})",
26 ai_info.provider, ai_info.model
27 );
28 eprintln!("Enter to send, Shift+Enter for newline, Ctrl+D to exit.\n");
29
30 let client = crate::claude::create_default_claude_client(self.model, None)?;
31
32 chat_loop(&client).await
33 }
34}
35
36pub async fn run_chat(
47 message: &str,
48 model: Option<String>,
49 system_prompt: Option<String>,
50) -> Result<String> {
51 crate::utils::preflight::check_ai_credentials(model.as_deref())?;
52 let client = crate::claude::create_default_claude_client(model, None)?;
53 let system = system_prompt
54 .as_deref()
55 .unwrap_or("You are a helpful assistant.");
56 client.send_message(system, message).await
57}
58
59async fn chat_loop(client: &crate::claude::client::ClaudeClient) -> Result<()> {
60 let system_prompt = "You are a helpful assistant.";
61
62 loop {
63 let input = match read_user_input() {
64 Ok(Some(text)) => text,
65 Ok(None) => {
66 eprintln!("\nGoodbye!");
67 break;
68 }
69 Err(e) => {
70 eprintln!("\nInput error: {e}");
71 break;
72 }
73 };
74
75 let trimmed = input.trim();
76 if trimmed.is_empty() {
77 continue;
78 }
79
80 let response = client.send_message(system_prompt, trimmed).await?;
81 println!("{response}\n");
82 }
83
84 Ok(())
85}
86
87struct RawModeGuard;
89
90impl Drop for RawModeGuard {
91 fn drop(&mut self) {
92 let _ = disable_raw_mode();
93 }
94}
95
96fn read_user_input() -> Result<Option<String>> {
100 eprint!("> ");
101 io::stderr().flush()?;
102
103 enable_raw_mode()?;
104 let _guard = RawModeGuard;
105
106 let mut buffer = String::new();
107
108 loop {
109 if let Event::Key(key_event) = event::read()? {
110 match key_event.code {
111 KeyCode::Enter => {
112 if key_event.modifiers.contains(KeyModifiers::SHIFT) {
113 buffer.push('\n');
114 eprint!("\r\n... ");
115 io::stderr().flush()?;
116 } else {
117 eprint!("\r\n");
118 io::stderr().flush()?;
119 return Ok(Some(buffer));
120 }
121 }
122 KeyCode::Char('d') if key_event.modifiers.contains(KeyModifiers::CONTROL) => {
123 if buffer.is_empty() {
124 return Ok(None);
125 }
126 eprint!("\r\n");
127 io::stderr().flush()?;
128 return Ok(Some(buffer));
129 }
130 KeyCode::Char('c') if key_event.modifiers.contains(KeyModifiers::CONTROL) => {
131 return Ok(None);
132 }
133 KeyCode::Char(c) => {
134 buffer.push(c);
135 eprint!("{c}");
136 io::stderr().flush()?;
137 }
138 KeyCode::Backspace if buffer.pop().is_some() => {
139 eprint!("\x08 \x08");
140 io::stderr().flush()?;
141 }
142 _ => {}
143 }
144 }
145 }
146}
147
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Serialises the tests in this module, which all mutate process-wide environment
    /// variables. A poisoned lock is recovered because every test re-isolates the
    /// environment it depends on.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Environment variables consulted during provider/credential detection; cleared
    /// for every test and restored afterwards.
    const KEYS: &[&str] = &[
        "USE_OPENAI",
        "USE_OLLAMA",
        "CLAUDE_CODE_USE_BEDROCK",
        "CLAUDE_API_KEY",
        "ANTHROPIC_API_KEY",
        "ANTHROPIC_AUTH_TOKEN",
        "ANTHROPIC_BEDROCK_BASE_URL",
        "OPENAI_API_KEY",
        "OPENAI_AUTH_TOKEN",
        "OLLAMA_MODEL",
        "OLLAMA_BASE_URL",
        "ANTHROPIC_MODEL",
    ];

    /// Saved values of `KEYS` plus `HOME`, written back when dropped.
    ///
    /// Restoring in `Drop` puts the environment back even when an assertion panics
    /// part-way through a test. `var_os` is used so non-UTF-8 values survive the round
    /// trip instead of being removed.
    struct EnvSnapshot(Vec<(&'static str, Option<std::ffi::OsString>)>);

    impl EnvSnapshot {
        fn capture() -> Self {
            Self(
                KEYS.iter()
                    .copied()
                    .chain(std::iter::once("HOME"))
                    .map(|key| (key, std::env::var_os(key)))
                    .collect(),
            )
        }
    }

    impl Drop for EnvSnapshot {
        fn drop(&mut self) {
            for (key, value) in self.0.drain(..) {
                match value {
                    Some(value) => std::env::set_var(key, value),
                    None => std::env::remove_var(key),
                }
            }
        }
    }

    /// Per-test isolation fixture.
    ///
    /// Struct fields drop in declaration order:
    /// 1. The temporary HOME is deleted.
    /// 2. The environment is restored.
    /// 3. `ENV_LOCK` is released.
    ///
    /// This order means no other test can observe a half-restored environment. The
    /// lock guard is held across `.await` points on purpose, which is why tests carry
    /// `#[allow(clippy::await_holding_lock)]`.
    struct IsolatedEnv {
        _home: tempfile::TempDir,
        _snapshot: EnvSnapshot,
        _lock: std::sync::MutexGuard<'static, ()>,
    }

    /// Takes `ENV_LOCK`, snapshots the environment, points `HOME` at a fresh empty
    /// directory, and removes every variable in `KEYS`.
    fn isolated_env() -> IsolatedEnv {
        let lock = ENV_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let snapshot = EnvSnapshot::capture();

        // NOTE(review): created under `./tmp` rather than the system temp dir, as before;
        // presumably to keep test artefacts inside the workspace — confirm.
        std::fs::create_dir_all("tmp").ok();
        let home = tempfile::TempDir::new_in("tmp").unwrap();
        std::env::set_var("HOME", home.path());
        for key in KEYS {
            std::env::remove_var(key);
        }

        IsolatedEnv {
            _home: home,
            _snapshot: snapshot,
            _lock: lock,
        }
    }

    /// Starts a mock OpenAI-compatible endpoint whose chat completions always answer
    /// `content`, and configures the Ollama provider to talk to it.
    ///
    /// Bind the returned server to a named variable (not `_`) so it stays alive while
    /// requests are made.
    async fn mock_ollama_replying(content: &str) -> wiremock::MockServer {
        let server = wiremock::MockServer::start().await;
        wiremock::Mock::given(wiremock::matchers::method("POST"))
            .and(wiremock::matchers::path("/v1/chat/completions"))
            .respond_with(
                wiremock::ResponseTemplate::new(200).set_body_json(serde_json::json!({
                    "id": "test",
                    "object": "chat.completion",
                    "choices": [{
                        "index": 0,
                        "message": {"role": "assistant", "content": content},
                        "finish_reason": "stop"
                    }]
                })),
            )
            .mount(&server)
            .await;

        std::env::set_var("USE_OLLAMA", "true");
        std::env::set_var("OLLAMA_MODEL", "llama2");
        std::env::set_var("OLLAMA_BASE_URL", server.uri());
        server
    }

    #[allow(clippy::await_holding_lock)]
    #[tokio::test]
    async fn run_chat_returns_error_when_credentials_missing() {
        let _env = isolated_env();

        let err = run_chat("hello", None, None).await.unwrap_err();
        let msg = format!("{err}");
        assert!(
            msg.contains("not found"),
            "expected credential error, got: {msg}"
        );
    }

    #[allow(clippy::await_holding_lock)]
    #[tokio::test]
    async fn run_chat_bubbles_up_credential_error_with_custom_system_prompt() {
        let _env = isolated_env();

        let err = run_chat("hello", None, Some("be terse".to_string()))
            .await
            .unwrap_err();
        assert!(format!("{err}").contains("not found"));
    }

    #[allow(clippy::await_holding_lock)]
    #[tokio::test]
    async fn run_chat_propagates_model_override_through_preflight() {
        let _env = isolated_env();

        let err = run_chat("hello", Some("claude-sonnet-4-6".to_string()), None)
            .await
            .unwrap_err();
        assert!(format!("{err}").contains("not found"));
    }

    #[allow(clippy::await_holding_lock)]
    #[tokio::test]
    async fn run_chat_happy_path_via_mocked_ollama_returns_response_text() {
        let _env = isolated_env();
        let _server = mock_ollama_replying("canned-response").await;

        let out = run_chat("hello", None, Some("be terse".to_string()))
            .await
            .unwrap();
        assert_eq!(out, "canned-response");
    }

    #[allow(clippy::await_holding_lock)]
    #[tokio::test]
    async fn run_chat_default_system_prompt_path_via_mocked_ollama() {
        let _env = isolated_env();
        let _server = mock_ollama_replying("ok").await;

        let out = run_chat("hello", None, None).await.unwrap();
        assert_eq!(out, "ok");
    }
}