1use anyhow::Result;
40use clap::Args;
41use serde_json::{Value, json};
42
43use crate::cli::CliOutput;
44use crate::config::AppConfig;
45use crate::models::field_names;
46
/// Exit code returned by [`run_with_llm`] when it is called without an LLM
/// client (`llm` is `None`), i.e. before any expansion request is attempted.
pub const EXIT_NO_LLM: i32 = 2;
50
/// Exit code returned by [`run_with_llm`] when a client was available but the
/// expansion call itself came back with an error.
pub const EXIT_LLM_FAILED: i32 = 3;
54
// Command-line arguments consumed by `cmd_expand` / `run_with_llm`.
//
// NOTE(review): plain `//` comments are used here on purpose. clap's derive
// macros turn `///` doc comments into `--help` text, and this edit is meant
// to leave the CLI's output untouched.
#[derive(Args, Debug, Clone)]
pub struct ExpandArgs {
    // Positional argument, displayed as `QUERY`. It is sent as the `"query"`
    // parameter of the expansion request and echoed back in JSON output.
    #[arg(value_name = "QUERY")]
    pub query: String,

    // `--json` switch. When set, results and errors are written to stdout as
    // one JSON object; otherwise human-readable text is produced and errors
    // go to stderr.
    #[arg(long)]
    pub json: bool,
}
70
71pub async fn cmd_expand(
89 args: &ExpandArgs,
90 app_config: &AppConfig,
91 out: &mut CliOutput<'_>,
92) -> Result<i32> {
93 let feature_tier = app_config.effective_tier(None);
94 let llm = crate::daemon_runtime::build_llm_client(feature_tier, app_config).await;
95 let key_source = app_config
96 .resolve_llm(None, None, None)
97 .api_key_source
98 .as_str()
99 .to_string();
100 run_with_llm(args, llm.as_ref(), &key_source, out)
101}
102
103pub fn run_with_llm(
115 args: &ExpandArgs,
116 llm: Option<&crate::llm::OllamaClient>,
117 key_source: &str,
118 out: &mut CliOutput<'_>,
119) -> Result<i32> {
120 if llm.is_none() {
121 let msg = "query expansion requires a configured LLM backend \
122 (set AI_MEMORY_LLM_BACKEND + key, or use smart/autonomous tier)";
123 if args.json {
124 writeln!(
125 out.stdout,
126 "{}",
127 serde_json::to_string(&json!({
128 "query": args.query,
129 "error": msg,
130 (field_names::KEY_SOURCE): key_source,
131 }))?
132 )?;
133 } else {
134 writeln!(out.stderr, "expand: {msg}")?;
135 }
136 return Ok(EXIT_NO_LLM);
137 }
138
139 let params = json!({ "query": args.query });
140 let started = std::time::Instant::now();
141 let result = crate::mcp::handle_expand_query(llm, ¶ms);
142 let elapsed_ms = u64::try_from(started.elapsed().as_millis()).unwrap_or(u64::MAX);
143
144 match result {
145 Ok(envelope) => {
146 let terms = envelope
147 .get(field_names::EXPANDED_TERMS)
148 .cloned()
149 .unwrap_or_else(|| json!([]));
150 if args.json {
151 writeln!(
152 out.stdout,
153 "{}",
154 serde_json::to_string(&json!({
155 "query": args.query,
156 (field_names::EXPANDED_TERMS): terms,
157 (field_names::ELAPSED_MS): elapsed_ms,
158 (field_names::KEY_SOURCE): key_source,
159 }))?
160 )?;
161 } else {
162 let term_strs: Vec<&str> = terms
163 .as_array()
164 .map_or_else(Vec::new, |a| a.iter().filter_map(Value::as_str).collect());
165 writeln!(
166 out.stdout,
167 "expand: {} term(s) (elapsed {elapsed_ms}ms, key_source={key_source})",
168 term_strs.len(),
169 )?;
170 for t in &term_strs {
171 writeln!(out.stdout, " - {t}")?;
172 }
173 }
174 Ok(0)
175 }
176 Err(e) => {
177 if args.json {
178 writeln!(
179 out.stdout,
180 "{}",
181 serde_json::to_string(&json!({
182 "query": args.query,
183 "error": e,
184 (field_names::ELAPSED_MS): elapsed_ms,
185 (field_names::KEY_SOURCE): key_source,
186 }))?
187 )?;
188 } else {
189 writeln!(out.stderr, "expand: LLM call failed: {e}")?;
190 }
191 Ok(EXIT_LLM_FAILED)
192 }
193 }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cli::test_utils::TestEnv;
    use crate::llm::OllamaClient;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Everything one `run_with_llm` invocation produced.
    struct Captured {
        stdout: String,
        stderr: String,
        code: i32,
    }

    /// Builds the CLI arguments for a test invocation.
    fn args(query: &str, json: bool) -> ExpandArgs {
        ExpandArgs {
            query: query.to_string(),
            json,
        }
    }

    /// Mounts a `GET /api/tags` stub that answers with an empty model list.
    async fn mount_tags_ok(server: &MockServer) {
        Mock::given(method("GET"))
            .and(path("/api/tags"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({"models": []})))
            .mount(server)
            .await;
    }

    /// Starts a mock server with the tags stub plus a `POST /api/chat` stub
    /// that answers with `chat`. The returned server must stay alive for as
    /// long as a client talks to it.
    async fn ollama_stub(chat: ResponseTemplate) -> MockServer {
        let server = MockServer::start().await;
        mount_tags_ok(&server).await;
        Mock::given(method("POST"))
            .and(path("/api/chat"))
            .respond_with(chat)
            .mount(&server)
            .await;
        server
    }

    /// Runs `run_with_llm` with no client and captures its output.
    fn run_without_client(json: bool) -> Captured {
        let mut env = TestEnv::fresh();
        let code = {
            let mut out = env.output();
            run_with_llm(&args("foo", json), None, "none", &mut out).expect("ok")
        };
        Captured {
            stdout: env.stdout_str().to_string(),
            stderr: env.stderr_str().to_string(),
            code,
        }
    }

    /// Runs `run_with_llm` against a client pointed at `uri`, on a blocking
    /// thread, and captures its output.
    async fn run_blocking(uri: String, cli_args: ExpandArgs, key_source: &'static str) -> Captured {
        tokio::task::spawn_blocking(move || {
            let mut env = TestEnv::fresh();
            let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
            let code = {
                let mut out = env.output();
                run_with_llm(&cli_args, Some(&client), key_source, &mut out).expect("ok")
            };
            Captured {
                stdout: env.stdout_str().to_string(),
                stderr: env.stderr_str().to_string(),
                code,
            }
        })
        .await
        .unwrap()
    }

    #[test]
    fn no_llm_json_emits_error_envelope_and_exit_no_llm() {
        let got = run_without_client(true);
        assert_eq!(got.code, EXIT_NO_LLM);
        let parsed: Value = serde_json::from_str(got.stdout.trim()).expect("json");
        assert_eq!(parsed["query"], "foo");
        assert!(parsed["error"].as_str().unwrap().contains("LLM backend"));
        assert_eq!(parsed[field_names::KEY_SOURCE], "none");
        assert!(got.stderr.is_empty());
    }

    #[test]
    fn no_llm_text_emits_stderr_and_exit_no_llm() {
        let got = run_without_client(false);
        assert_eq!(got.code, EXIT_NO_LLM);
        assert!(got.stdout.is_empty());
        assert!(got.stderr.contains("expand:"));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn success_json_emits_terms_envelope() {
        let reply = ResponseTemplate::new(200).set_body_json(json!({
            "message": {"content": "alpha\nbeta\n"},
        }));
        let server = ollama_stub(reply).await;
        let got = run_blocking(server.uri(), args("nets", true), "env").await;
        assert_eq!(got.code, 0);
        let parsed: Value = serde_json::from_str(got.stdout.trim()).expect("json");
        assert_eq!(parsed["query"], "nets");
        let terms = parsed[field_names::EXPANDED_TERMS].as_array().unwrap();
        assert_eq!(terms.len(), 2);
        assert_eq!(parsed[field_names::KEY_SOURCE], "env");
        assert!(parsed[field_names::ELAPSED_MS].is_u64());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn success_text_lists_terms() {
        let reply = ResponseTemplate::new(200).set_body_json(json!({
            "message": {"content": "one\ntwo\nthree\n"},
        }));
        let server = ollama_stub(reply).await;
        let got = run_blocking(server.uri(), args("q", false), "config").await;
        assert_eq!(got.code, 0);
        assert!(got.stdout.contains("3 term(s)"));
        assert!(got.stdout.contains("- one"));
        assert!(got.stdout.contains("- three"));
        assert!(got.stdout.contains("key_source=config"));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn llm_error_json_returns_exit_llm_failed() {
        let reply = ResponseTemplate::new(500).set_body_string("boom");
        let server = ollama_stub(reply).await;
        let got = run_blocking(server.uri(), args("q", true), "env").await;
        assert_eq!(got.code, EXIT_LLM_FAILED);
        let parsed: Value = serde_json::from_str(got.stdout.trim()).expect("json");
        assert!(parsed["error"].is_string());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn llm_error_text_returns_exit_llm_failed() {
        let reply = ResponseTemplate::new(500).set_body_string("boom");
        let server = ollama_stub(reply).await;
        let got = run_blocking(server.uri(), args("q", false), "env").await;
        assert_eq!(got.code, EXIT_LLM_FAILED);
        assert!(got.stderr.contains("LLM call failed"));
    }
}