Skip to main content

engram/
llm_command.rs

1//! `CommandLlmClient` — the shell-out extensibility escape hatch.
2//!
3//! Spawns a user-supplied command per extraction call, writes a JSON request
4//! to its stdin, and reads a JSON response from its stdout. Use this to plug
5//! in any provider Engram does not ship natively — an internal corporate LLM
6//! gateway, an exotic model behind a custom RPC, a local inference binary,
7//! or a quick wrapper over any HTTP API with whatever auth scheme you need.
8//!
9//! # The contract
10//!
11//! Engram invokes your command through `sh -c <command>` (so `$HOME`,
12//! pipes, redirects, and environment variables all work). The command must:
13//!
14//! 1. Read exactly one JSON object from stdin, matching:
15//!    ```json
16//!    {"system": "...", "user": "...", "structured": true}
17//!    ```
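//!    `structured` is `true` when Engram wants JSON output (extraction,
//!    consolidation) and `false` when it wants plain text. Even for plain
//!    text the reply must be JSON: emit the text as a JSON string (e.g.
//!    `"hello"`); any other JSON value is handed back in serialized form.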
20//!
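//! 2. Write exactly one JSON value to stdout, either:
//!    - Plain: the JSON value that Engram should use directly.
//!    - Enveloped: `{"content": <value>}` or `{"error": "message"}`.
//!
//!    A top-level object with a string `"error"` key or a non-null
//!    `"content"` key is treated as an envelope, so if your real output has
//!    that shape, wrap it as `{"content": <value>}`.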
24//!
//! 3. Exit with code 0 on success and non-zero on failure. On a non-zero
//!    exit, the command's stderr is included in the error Engram reports.
//!    Empty stdout is also treated as a failure, even with exit code 0.
27//!
28//! # Example wrapper (Python, ~15 lines)
29//!
30//! ```text
31//! #!/usr/bin/env python3
32//! # my-llm.py — wraps any Python LLM SDK for Engram.
33//! import json, sys
34//! from my_llm_sdk import chat  # your SDK
35//!
36//! req = json.loads(sys.stdin.read())
37//! resp = chat(system=req["system"], user=req["user"],
38//!             json_mode=req.get("structured", False))
39//! # Either write the raw response:
40//! sys.stdout.write(json.dumps(resp))
41//! # Or envelope it:
42//! # sys.stdout.write(json.dumps({"content": resp}))
43//! ```
44//!
45//! Then point Engram at it:
46//! ```text
47//! ENGRAM_LLM_PROVIDER=command \
48//! ENGRAM_LLM_COMMAND="python /path/to/my-llm.py" \
49//! engram serve
50//! ```
51//!
52//! # Security
53//!
//! `CommandLlmClient` runs **arbitrary commands** as the Engram process user.
//! Never expose it in a multi-tenant deployment where untrusted users can
//! control `ENGRAM_LLM_COMMAND`; it is intended for local, single-tenant use.
//! Prompt text is passed to the command only on stdin and is never
//! interpolated into the shell command line.
57
58use crate::llm::LlmClient;
59use crate::llm_util::extract_json_payload;
60use crate::store::MemoryError;
61use async_trait::async_trait;
62use serde::{Deserialize, Serialize};
63use std::process::Stdio;
64use tokio::io::AsyncWriteExt;
65use tokio::process::Command;
66
/// Shell-out [`LlmClient`] that runs a user-supplied command once per request.
///
/// Cloning copies only the command string and the per-call timeout.
#[derive(Clone, Debug)]
pub struct CommandLlmClient {
    /// Command line handed to `sh -c` on every call.
    command: String,
    /// Per-call timeout, in seconds.
    timeout_secs: u64,
}

impl CommandLlmClient {
    /// Timeout used by [`CommandLlmClient::new`] until overridden.
    const DEFAULT_TIMEOUT_SECS: u64 = 120;

    /// Creates a client that executes `command` through `sh -c` for each call.
    ///
    /// The per-call timeout starts at 120 seconds; change it with
    /// [`CommandLlmClient::with_timeout`].
    pub fn new(command: impl Into<String>) -> Self {
        let command = command.into();
        Self {
            command,
            timeout_secs: Self::DEFAULT_TIMEOUT_SECS,
        }
    }

    /// Returns this client with its per-call timeout replaced by `secs` seconds.
    pub fn with_timeout(self, secs: u64) -> Self {
        Self {
            timeout_secs: secs,
            ..self
        }
    }
}
91
92#[derive(Serialize)]
93struct CommandRequest<'a> {
94    system: &'a str,
95    user: &'a str,
96    structured: bool,
97}
98
99/// Optional envelope wrapping the command's output. Commands MAY return a
100/// bare JSON value or this envelope. If both `content` and `error` are
101/// present, `error` wins.
102#[derive(Deserialize)]
103struct CommandEnvelope {
104    #[serde(default)]
105    content: Option<serde_json::Value>,
106    #[serde(default)]
107    error: Option<String>,
108}
109
110impl CommandLlmClient {
111    async fn call(
112        &self,
113        system: &str,
114        user: &str,
115        structured: bool,
116    ) -> Result<serde_json::Value, MemoryError> {
117        if self.command.trim().is_empty() {
118            return Err(MemoryError::Database(
119                "CommandLlmClient: command is empty".into(),
120            ));
121        }
122
123        let mut child = Command::new("sh")
124            .arg("-c")
125            .arg(&self.command)
126            .stdin(Stdio::piped())
127            .stdout(Stdio::piped())
128            .stderr(Stdio::piped())
129            .spawn()
130            .map_err(|e| MemoryError::Database(format!("command spawn failed: {e}")))?;
131
132        let request_json = serde_json::to_string(&CommandRequest {
133            system,
134            user,
135            structured,
136        })
137        .map_err(|e| MemoryError::Serialization(format!("CommandRequest serialize: {e}")))?;
138
139        if let Some(mut stdin) = child.stdin.take() {
140            stdin
141                .write_all(request_json.as_bytes())
142                .await
143                .map_err(|e| MemoryError::Database(format!("command stdin write: {e}")))?;
144            // stdin dropped here, signalling EOF to the child.
145        }
146
147        let timeout = std::time::Duration::from_secs(self.timeout_secs);
148        let output = tokio::time::timeout(timeout, child.wait_with_output())
149            .await
150            .map_err(|_| {
151                MemoryError::Database(format!("command timed out after {}s", self.timeout_secs))
152            })?
153            .map_err(|e| MemoryError::Database(format!("command wait failed: {e}")))?;
154
155        if !output.status.success() {
156            let stderr = String::from_utf8_lossy(&output.stderr);
157            let code = output
158                .status
159                .code()
160                .map(|c| c.to_string())
161                .unwrap_or_else(|| "signal".to_string());
162            return Err(MemoryError::Database(format!(
163                "command exited with code {code}: {}",
164                stderr.trim()
165            )));
166        }
167
168        let stdout = String::from_utf8_lossy(&output.stdout).to_string();
169        let payload = extract_json_payload(&stdout);
170
171        if payload.is_empty() {
172            return Err(MemoryError::Database(
173                "command produced empty stdout".into(),
174            ));
175        }
176
177        // First try to parse as an envelope and honour `error` / `content`.
178        if let Ok(envelope) = serde_json::from_str::<CommandEnvelope>(payload) {
179            if let Some(err) = envelope.error {
180                return Err(MemoryError::Database(format!("command error: {err}")));
181            }
182            if let Some(content) = envelope.content {
183                return Ok(content);
184            }
185        }
186
187        // Fall through: treat the whole stdout as the JSON value itself.
188        serde_json::from_str(payload).map_err(|e| {
189            MemoryError::Serialization(format!(
190                "command JSON parse: {e} (stdout head: {})",
191                payload.chars().take(200).collect::<String>()
192            ))
193        })
194    }
195}
196
197#[async_trait]
198impl LlmClient for CommandLlmClient {
199    async fn complete(&self, system: &str, user: &str) -> Result<String, MemoryError> {
200        let value = self.call(system, user, false).await?;
201        match value {
202            serde_json::Value::String(s) => Ok(s),
203            other => Ok(other.to_string()),
204        }
205    }
206
207    async fn structured_output(
208        &self,
209        system: &str,
210        user: &str,
211    ) -> Result<serde_json::Value, MemoryError> {
212        self.call(system, user, true).await
213    }
214}
215
216#[cfg(test)]
217mod tests {
218    use super::*;
219
220    #[tokio::test]
221    async fn returns_raw_json_output() {
222        let client = CommandLlmClient::new(r#"cat > /dev/null; echo '{"facts":[]}'"#);
223        let result = client.structured_output("sys", "user").await.unwrap();
224        assert_eq!(result, serde_json::json!({"facts": []}));
225    }
226
227    #[tokio::test]
228    async fn returns_envelope_content() {
229        let client = CommandLlmClient::new(
230            r#"cat > /dev/null; echo '{"content":{"facts":[{"text":"hello"}]}}'"#,
231        );
232        let result = client.structured_output("sys", "user").await.unwrap();
233        assert_eq!(result, serde_json::json!({"facts": [{"text": "hello"}]}));
234    }
235
236    #[tokio::test]
237    async fn envelope_error_surfaces() {
238        let client =
239            CommandLlmClient::new(r#"cat > /dev/null; echo '{"error":"deliberate failure"}'"#);
240        let err = client
241            .structured_output("sys", "user")
242            .await
243            .expect_err("should error");
244        assert!(err.to_string().contains("deliberate failure"));
245    }
246
247    #[tokio::test]
248    async fn nonzero_exit_is_error() {
249        let client = CommandLlmClient::new(r#"cat > /dev/null; echo 'oops' >&2; exit 7"#);
250        let err = client
251            .structured_output("sys", "user")
252            .await
253            .expect_err("should error");
254        let msg = err.to_string();
255        assert!(msg.contains("code 7"), "expected exit code in error: {msg}");
256    }
257
258    #[tokio::test]
259    async fn empty_command_rejected() {
260        let client = CommandLlmClient::new("   ");
261        let err = client
262            .structured_output("sys", "user")
263            .await
264            .expect_err("should error");
265        assert!(err.to_string().contains("empty"));
266    }
267
268    #[tokio::test]
269    async fn complete_returns_string_text() {
270        let client = CommandLlmClient::new(r#"cat > /dev/null; echo '"hello there"'"#);
271        let result = client.complete("sys", "user").await.unwrap();
272        assert_eq!(result, "hello there");
273    }
274
275    #[tokio::test]
276    async fn command_sees_request_on_stdin() {
277        // Use `wc -c` to count the stdin bytes the child receives, then echo
278        // a JSON envelope back to verify bidirectional flow.
279        let client =
280            CommandLlmClient::new(r#"bytes=$(wc -c); echo "{\"content\":{\"bytes\":$bytes}}""#);
281        let result = client
282            .structured_output("system prompt", "user prompt")
283            .await
284            .unwrap();
285        // Request JSON is non-empty — exact byte count depends on serde output,
286        // but it should be > 30 bytes for the fields we pass.
287        let bytes = result["bytes"].as_u64().unwrap_or(0);
288        assert!(bytes > 30, "expected stdin bytes > 30, got {bytes}");
289    }
290}