Skip to main content

ai_memory/cli/commands/
expand.rs

// Copyright 2026 AlphaOne LLC
// SPDX-License-Identifier: Apache-2.0

//! v0.7.0 #1443 — `ai-memory expand` CLI subcommand.
//!
//! Closes the three-surface-parity gap on query expansion. The MCP tool
//! ([`crate::mcp::handle_expand_query`]) and the HTTP route
//! (`POST /api/v1/expand_query`) landed previously; this module wires
//! the CLI surface so an automation harness (notably the binary-faithful
//! `benchmarks/longmemeval/harness.py`) can inject LLM query expansion
//! through a `recall`-style one-shot invocation of the binary — without
//! standing up an MCP stdio server or an HTTP daemon per call.
//!
//! ## DRY contract
//!
//! No expansion logic lives here — this module is a clap arg-parser plus
//! an output formatter. The term-generation primitive
//! ([`crate::llm::OllamaClient::expand_query`] — whose sync form is a
//! `block_on` of `expand_query_async`) is the single source of expanded
//! terms for every surface, so the term set is byte-equal across MCP,
//! HTTP, and CLI. MCP (`memory_expand_query`) and this CLI surface
//! additionally share the envelope-shaping helper
//! [`crate::mcp::handle_expand_query`] (`{original, expanded_terms}`); the
//! HTTP route (`POST /api/v1/expand_query`) keeps its own async handler
//! for the Axum runtime + H8 timeout + 503/502/400 status granularity but
//! emits the same `{original, expanded_terms}` envelope (parity pinned by
//! `tests/l07_3_chunk_d_http_surface.rs`; #1445). The CLI does not print
//! that envelope verbatim: it lifts `expanded_terms` out of it and emits
//! `{query, expanded_terms, elapsed_ms, key_source}` under `--json` (on
//! failure, an `error` string takes the place of `expanded_terms`).
//!
//! ## LLM resolution
//!
//! The LLM client is resolved through the same
//! [`crate::daemon_runtime::build_llm_client`] ladder the daemon uses
//! (CLI flag > `AI_MEMORY_LLM_*` env > `[llm]` section > legacy fields >
//! compiled tier preset). `expand` defines no LLM flags of its own — it
//! passes `None` for every override — so only the remaining rungs apply.
//! An entirely Ollama-free configuration —
//! `AI_MEMORY_LLM_BACKEND=openrouter` plus a key — drives expansion
//! against a cloud backend, which is exactly the no-Ollama path the
//! v0.7.0 LongMemEval reproduction exercised.
38
39use anyhow::Result;
40use clap::Args;
41use serde_json::{Value, json};
42
43use crate::cli::CliOutput;
44use crate::config::AppConfig;
45use crate::models::field_names;
46
/// Exit code returned when no LLM backend is configured at all.
///
/// This is the 503 analogue: the expansion primitive cannot be reached,
/// as opposed to being reached and then failing.
pub const EXIT_NO_LLM: i32 = 2;

/// Exit code returned when an LLM backend is configured but the
/// expansion call itself failed — the 502 analogue (an upstream error).
pub const EXIT_LLM_FAILED: i32 = 3;
54
55/// CLI args for `ai-memory expand`. Mirrors the MCP `memory_expand_query`
56/// `input_schema` shape (a single free-text `query`).
57#[derive(Args, Debug, Clone)]
58pub struct ExpandArgs {
59    /// Free-text query to expand into semantic reformulations.
60    #[arg(value_name = "QUERY")]
61    pub query: String,
62
63    /// Emit the raw JSON envelope
64    /// (`{query, expanded_terms, elapsed_ms, key_source}`) on stdout
65    /// instead of a human-readable summary. Built for harness
66    /// consumption.
67    #[arg(long)]
68    pub json: bool,
69}
70
71/// `ai-memory expand` dispatch entry. Resolves the LLM client through
72/// the daemon ladder, routes the query through the shared substrate
73/// primitive ([`crate::mcp::handle_expand_query`]), and emits the
74/// expanded terms — guaranteeing the term set is byte-equal with the
75/// MCP / HTTP surfaces.
76///
77/// Returns an exit code rather than propagating an error so the no-LLM
78/// and upstream-failure cases get stable, harness-detectable codes:
79/// - `0` — success, terms emitted.
80/// - [`EXIT_NO_LLM`] (`2`) — no LLM configured (503-equivalent).
81/// - [`EXIT_LLM_FAILED`] (`3`) — LLM configured but the call failed.
82///
83/// # Errors
84///
85/// Propagates only fatal I/O errors (writing to stdout/stderr) and
86/// `serde_json::to_string` serialisation failures. Every expansion
87/// outcome is mapped to an exit code and returned as `Ok(code)`.
88pub async fn cmd_expand(
89    args: &ExpandArgs,
90    app_config: &AppConfig,
91    out: &mut CliOutput<'_>,
92) -> Result<i32> {
93    let feature_tier = app_config.effective_tier(None);
94    let llm = crate::daemon_runtime::build_llm_client(feature_tier, app_config).await;
95    let key_source = app_config
96        .resolve_llm(None, None, None)
97        .api_key_source
98        .as_str()
99        .to_string();
100    run_with_llm(args, llm.as_ref(), &key_source, out)
101}
102
103/// Visible-for-test core. Production resolves the client via
104/// [`cmd_expand`]; the test suite injects a wiremock-backed
105/// [`crate::llm::OllamaClient`] (or `None` for the no-LLM path) so the
106/// exit-code contract can be pinned without a live LLM. `key_source` is
107/// the resolved API-key provenance label (e.g. `env`, `config`, `none`)
108/// surfaced in the envelope for harness observability.
109///
110/// # Errors
111///
112/// Propagates only fatal stdout/stderr I/O errors and JSON
113/// serialisation failures. See [`cmd_expand`] for the exit-code map.
114pub fn run_with_llm(
115    args: &ExpandArgs,
116    llm: Option<&crate::llm::OllamaClient>,
117    key_source: &str,
118    out: &mut CliOutput<'_>,
119) -> Result<i32> {
120    if llm.is_none() {
121        let msg = "query expansion requires a configured LLM backend \
122                   (set AI_MEMORY_LLM_BACKEND + key, or use smart/autonomous tier)";
123        if args.json {
124            writeln!(
125                out.stdout,
126                "{}",
127                serde_json::to_string(&json!({
128                    "query": args.query,
129                    "error": msg,
130                    (field_names::KEY_SOURCE): key_source,
131                }))?
132            )?;
133        } else {
134            writeln!(out.stderr, "expand: {msg}")?;
135        }
136        return Ok(EXIT_NO_LLM);
137    }
138
139    let params = json!({ "query": args.query });
140    let started = std::time::Instant::now();
141    let result = crate::mcp::handle_expand_query(llm, &params);
142    let elapsed_ms = u64::try_from(started.elapsed().as_millis()).unwrap_or(u64::MAX);
143
144    match result {
145        Ok(envelope) => {
146            let terms = envelope
147                .get(field_names::EXPANDED_TERMS)
148                .cloned()
149                .unwrap_or_else(|| json!([]));
150            if args.json {
151                writeln!(
152                    out.stdout,
153                    "{}",
154                    serde_json::to_string(&json!({
155                        "query": args.query,
156                        (field_names::EXPANDED_TERMS): terms,
157                        (field_names::ELAPSED_MS): elapsed_ms,
158                        (field_names::KEY_SOURCE): key_source,
159                    }))?
160                )?;
161            } else {
162                let term_strs: Vec<&str> = terms
163                    .as_array()
164                    .map_or_else(Vec::new, |a| a.iter().filter_map(Value::as_str).collect());
165                writeln!(
166                    out.stdout,
167                    "expand: {} term(s) (elapsed {elapsed_ms}ms, key_source={key_source})",
168                    term_strs.len(),
169                )?;
170                for t in &term_strs {
171                    writeln!(out.stdout, "  - {t}")?;
172                }
173            }
174            Ok(0)
175        }
176        Err(e) => {
177            if args.json {
178                writeln!(
179                    out.stdout,
180                    "{}",
181                    serde_json::to_string(&json!({
182                        "query": args.query,
183                        "error": e,
184                        (field_names::ELAPSED_MS): elapsed_ms,
185                        (field_names::KEY_SOURCE): key_source,
186                    }))?
187                )?;
188            } else {
189                writeln!(out.stderr, "expand: LLM call failed: {e}")?;
190            }
191            Ok(EXIT_LLM_FAILED)
192        }
193    }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cli::test_utils::TestEnv;
    use crate::llm::OllamaClient;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    fn args(query: &str, json: bool) -> ExpandArgs {
        ExpandArgs {
            query: query.to_string(),
            json,
        }
    }

    async fn mount_tags_ok(server: &MockServer) {
        Mock::given(method("GET"))
            .and(path("/api/tags"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({"models": []})))
            .mount(server)
            .await;
    }

    /// `/api/chat` reply whose message content is `content`.
    fn chat_ok(content: &str) -> ResponseTemplate {
        ResponseTemplate::new(200).set_body_json(json!({
            "message": {"content": content},
        }))
    }

    /// `/api/chat` reply that fails upstream with HTTP 500.
    fn chat_fail() -> ResponseTemplate {
        ResponseTemplate::new(500).set_body_string("boom")
    }

    /// Starts a mock Ollama with a healthy `/api/tags` and the given
    /// `/api/chat` reply. Keep the returned server alive for the whole test.
    async fn mock_ollama(chat: ResponseTemplate) -> MockServer {
        let server = MockServer::start().await;
        mount_tags_ok(&server).await;
        Mock::given(method("POST"))
            .and(path("/api/chat"))
            .respond_with(chat)
            .mount(&server)
            .await;
        server
    }

    /// Runs [`run_with_llm`] with a client pointed at `uri` and returns
    /// `(exit_code, stdout, stderr)`.
    ///
    /// The client is built and driven inside `spawn_blocking` because the
    /// expansion path is synchronous (the module docs describe it as a
    /// `block_on` over the async call).
    async fn run_mocked(
        uri: String,
        a: ExpandArgs,
        key_source: &'static str,
    ) -> (i32, String, String) {
        tokio::task::spawn_blocking(move || {
            let mut env = TestEnv::fresh();
            let client = OllamaClient::new_with_url(&uri, "test-model").unwrap();
            let code = {
                let mut out = env.output();
                run_with_llm(&a, Some(&client), key_source, &mut out).expect("ok")
            };
            (code, env.stdout_str().to_string(), env.stderr_str().to_string())
        })
        .await
        .unwrap()
    }

    #[test]
    fn no_llm_json_emits_error_envelope_and_exit_no_llm() {
        let mut env = TestEnv::fresh();
        let code = {
            let mut out = env.output();
            run_with_llm(&args("foo", true), None, "none", &mut out).expect("ok")
        };
        assert_eq!(code, EXIT_NO_LLM);
        let parsed: Value = serde_json::from_str(env.stdout_str().trim()).expect("json");
        assert_eq!(parsed["query"], "foo");
        assert!(parsed["error"].as_str().unwrap().contains("LLM backend"));
        assert_eq!(parsed[field_names::KEY_SOURCE], "none");
        assert!(parsed.get(field_names::EXPANDED_TERMS).is_none());
        assert!(env.stderr_str().is_empty());
    }

    #[test]
    fn no_llm_text_emits_stderr_and_exit_no_llm() {
        let mut env = TestEnv::fresh();
        let code = {
            let mut out = env.output();
            run_with_llm(&args("foo", false), None, "none", &mut out).expect("ok")
        };
        assert_eq!(code, EXIT_NO_LLM);
        assert!(env.stdout_str().is_empty());
        assert!(env.stderr_str().contains("expand:"));
        assert!(env.stderr_str().contains("LLM backend"));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn success_json_emits_terms_envelope() {
        let server = mock_ollama(chat_ok("alpha\nbeta\n")).await;
        let (code, stdout, stderr) = run_mocked(server.uri(), args("nets", true), "env").await;
        assert_eq!(code, 0);
        assert!(stderr.is_empty());
        let parsed: Value = serde_json::from_str(stdout.trim()).expect("json");
        assert_eq!(parsed["query"], "nets");
        let terms = parsed[field_names::EXPANDED_TERMS].as_array().unwrap();
        assert_eq!(terms.len(), 2);
        assert!(terms.iter().any(|t| t.as_str() == Some("alpha")));
        assert!(terms.iter().any(|t| t.as_str() == Some("beta")));
        assert_eq!(parsed[field_names::KEY_SOURCE], "env");
        assert!(parsed[field_names::ELAPSED_MS].is_u64());
        assert!(parsed.get("error").is_none());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn success_text_lists_terms() {
        let server = mock_ollama(chat_ok("one\ntwo\nthree\n")).await;
        let (code, stdout, stderr) = run_mocked(server.uri(), args("q", false), "config").await;
        assert_eq!(code, 0);
        assert!(stderr.is_empty());
        assert!(stdout.contains("3 term(s)"));
        assert!(stdout.contains("- one"));
        assert!(stdout.contains("- two"));
        assert!(stdout.contains("- three"));
        assert!(stdout.contains("key_source=config"));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn llm_error_json_returns_exit_llm_failed() {
        let server = mock_ollama(chat_fail()).await;
        let (code, stdout, stderr) = run_mocked(server.uri(), args("q", true), "env").await;
        assert_eq!(code, EXIT_LLM_FAILED);
        assert!(stderr.is_empty());
        let parsed: Value = serde_json::from_str(stdout.trim()).expect("json");
        assert_eq!(parsed["query"], "q");
        assert!(parsed["error"].is_string());
        assert_eq!(parsed[field_names::KEY_SOURCE], "env");
        assert!(parsed[field_names::ELAPSED_MS].is_u64());
        assert!(parsed.get(field_names::EXPANDED_TERMS).is_none());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn llm_error_text_returns_exit_llm_failed() {
        let server = mock_ollama(chat_fail()).await;
        let (code, stdout, stderr) = run_mocked(server.uri(), args("q", false), "env").await;
        assert_eq!(code, EXIT_LLM_FAILED);
        assert!(stdout.is_empty());
        assert!(stderr.contains("LLM call failed"));
    }
}