1use crate::llm::LlmClient;
59use crate::llm_util::extract_json_payload;
60use crate::store::MemoryError;
61use async_trait::async_trait;
62use serde::{Deserialize, Serialize};
63use std::process::Stdio;
64use tokio::io::AsyncWriteExt;
65use tokio::process::Command;
66
/// An `LlmClient` that delegates each request to an external shell command.
///
/// The command string is run with `sh -c`; the request is written to the command's stdin
/// as JSON and the reply is read back from its stdout.
#[derive(Clone, Debug)]
pub struct CommandLlmClient {
    /// Shell command line handed to `sh -c` on every invocation.
    command: String,
    /// Maximum wall-clock time, in seconds, allowed for a single invocation.
    timeout_secs: u64,
}

impl CommandLlmClient {
    /// Timeout applied by [`CommandLlmClient::new`], in seconds.
    const DEFAULT_TIMEOUT_SECS: u64 = 120;

    /// Creates a client that runs `command` through `sh -c` with a 120-second timeout.
    #[must_use]
    pub fn new(command: impl Into<String>) -> Self {
        Self {
            command: command.into(),
            timeout_secs: Self::DEFAULT_TIMEOUT_SECS,
        }
    }

    /// Returns this client with the per-invocation timeout set to `secs` seconds.
    ///
    /// Consumes and returns `self`; the result must be kept (`#[must_use]`), otherwise
    /// the new timeout is silently lost.
    #[must_use]
    pub fn with_timeout(mut self, secs: u64) -> Self {
        self.timeout_secs = secs;
        self
    }
}
91
/// JSON request written to the command's stdin, once per invocation.
///
/// Serialized with serde's default field naming, so the field names (and their order in
/// the emitted object) form the protocol the external command reads.
#[derive(Serialize)]
struct CommandRequest<'a> {
    /// The `system` argument given to `complete` / `structured_output`.
    system: &'a str,
    /// The `user` argument given to `complete` / `structured_output`.
    user: &'a str,
    /// `true` when invoked through `structured_output`, `false` through `complete`.
    structured: bool,
}
98
/// Optional wrapper object a command may print on stdout instead of a bare JSON value.
///
/// When stdout parses as this type, a present `error` takes precedence and becomes an
/// `Err`; otherwise a present `content` is returned as the result. serde ignores unknown
/// fields by default, so a bare JSON object carrying neither key also parses here (both
/// fields `None`) and `CommandLlmClient::call` then falls back to returning the whole
/// payload as the raw result.
#[derive(Deserialize)]
struct CommandEnvelope {
    /// Result value returned to the caller when no `error` is present.
    #[serde(default)]
    content: Option<serde_json::Value>,
    /// Failure message reported by the command itself; surfaced as an error.
    #[serde(default)]
    error: Option<String>,
}
109
110impl CommandLlmClient {
111 async fn call(
112 &self,
113 system: &str,
114 user: &str,
115 structured: bool,
116 ) -> Result<serde_json::Value, MemoryError> {
117 if self.command.trim().is_empty() {
118 return Err(MemoryError::Database(
119 "CommandLlmClient: command is empty".into(),
120 ));
121 }
122
123 let mut child = Command::new("sh")
124 .arg("-c")
125 .arg(&self.command)
126 .stdin(Stdio::piped())
127 .stdout(Stdio::piped())
128 .stderr(Stdio::piped())
129 .spawn()
130 .map_err(|e| MemoryError::Database(format!("command spawn failed: {e}")))?;
131
132 let request_json = serde_json::to_string(&CommandRequest {
133 system,
134 user,
135 structured,
136 })
137 .map_err(|e| MemoryError::Serialization(format!("CommandRequest serialize: {e}")))?;
138
139 if let Some(mut stdin) = child.stdin.take() {
140 stdin
141 .write_all(request_json.as_bytes())
142 .await
143 .map_err(|e| MemoryError::Database(format!("command stdin write: {e}")))?;
144 }
146
147 let timeout = std::time::Duration::from_secs(self.timeout_secs);
148 let output = tokio::time::timeout(timeout, child.wait_with_output())
149 .await
150 .map_err(|_| {
151 MemoryError::Database(format!("command timed out after {}s", self.timeout_secs))
152 })?
153 .map_err(|e| MemoryError::Database(format!("command wait failed: {e}")))?;
154
155 if !output.status.success() {
156 let stderr = String::from_utf8_lossy(&output.stderr);
157 let code = output
158 .status
159 .code()
160 .map(|c| c.to_string())
161 .unwrap_or_else(|| "signal".to_string());
162 return Err(MemoryError::Database(format!(
163 "command exited with code {code}: {}",
164 stderr.trim()
165 )));
166 }
167
168 let stdout = String::from_utf8_lossy(&output.stdout).to_string();
169 let payload = extract_json_payload(&stdout);
170
171 if payload.is_empty() {
172 return Err(MemoryError::Database(
173 "command produced empty stdout".into(),
174 ));
175 }
176
177 if let Ok(envelope) = serde_json::from_str::<CommandEnvelope>(payload) {
179 if let Some(err) = envelope.error {
180 return Err(MemoryError::Database(format!("command error: {err}")));
181 }
182 if let Some(content) = envelope.content {
183 return Ok(content);
184 }
185 }
186
187 serde_json::from_str(payload).map_err(|e| {
189 MemoryError::Serialization(format!(
190 "command JSON parse: {e} (stdout head: {})",
191 payload.chars().take(200).collect::<String>()
192 ))
193 })
194 }
195}
196
197#[async_trait]
198impl LlmClient for CommandLlmClient {
199 async fn complete(&self, system: &str, user: &str) -> Result<String, MemoryError> {
200 let value = self.call(system, user, false).await?;
201 match value {
202 serde_json::Value::String(s) => Ok(s),
203 other => Ok(other.to_string()),
204 }
205 }
206
207 async fn structured_output(
208 &self,
209 system: &str,
210 user: &str,
211 ) -> Result<serde_json::Value, MemoryError> {
212 self.call(system, user, true).await
213 }
214}
215
// End-to-end tests: every case spawns a real `sh` process, so they require a POSIX shell.
#[cfg(test)]
mod tests {
    use super::*;

    // Most fixtures begin with `cat > /dev/null` so the command drains the JSON request on
    // stdin before replying — presumably so the client's stdin write cannot fail because
    // the command exited first.

    // A bare (non-envelope) JSON object on stdout is returned unchanged.
    #[tokio::test]
    async fn returns_raw_json_output() {
        let client = CommandLlmClient::new(r#"cat > /dev/null; echo '{"facts":[]}'"#);
        let result = client.structured_output("sys", "user").await.unwrap();
        assert_eq!(result, serde_json::json!({"facts": []}));
    }

    // An envelope's `content` value is unwrapped and returned.
    #[tokio::test]
    async fn returns_envelope_content() {
        let client = CommandLlmClient::new(
            r#"cat > /dev/null; echo '{"content":{"facts":[{"text":"hello"}]}}'"#,
        );
        let result = client.structured_output("sys", "user").await.unwrap();
        assert_eq!(result, serde_json::json!({"facts": [{"text": "hello"}]}));
    }

    // An envelope's `error` message is propagated into the returned error.
    #[tokio::test]
    async fn envelope_error_surfaces() {
        let client =
            CommandLlmClient::new(r#"cat > /dev/null; echo '{"error":"deliberate failure"}'"#);
        let err = client
            .structured_output("sys", "user")
            .await
            .expect_err("should error");
        assert!(err.to_string().contains("deliberate failure"));
    }

    // A non-zero exit status is an error whose message includes the exit code.
    #[tokio::test]
    async fn nonzero_exit_is_error() {
        let client = CommandLlmClient::new(r#"cat > /dev/null; echo 'oops' >&2; exit 7"#);
        let err = client
            .structured_output("sys", "user")
            .await
            .expect_err("should error");
        let msg = err.to_string();
        assert!(msg.contains("code 7"), "expected exit code in error: {msg}");
    }

    // A whitespace-only command is rejected before any process is spawned.
    #[tokio::test]
    async fn empty_command_rejected() {
        let client = CommandLlmClient::new("   ");
        let err = client
            .structured_output("sys", "user")
            .await
            .expect_err("should error");
        assert!(err.to_string().contains("empty"));
    }

    // `complete` returns a JSON string reply without its surrounding quotes.
    #[tokio::test]
    async fn complete_returns_string_text() {
        let client = CommandLlmClient::new(r#"cat > /dev/null; echo '"hello there"'"#);
        let result = client.complete("sys", "user").await.unwrap();
        assert_eq!(result, "hello there");
    }

    // The request really arrives on stdin: `wc -c` counts the bytes written, and a JSON
    // request carrying both prompts must be longer than 30 bytes.
    #[tokio::test]
    async fn command_sees_request_on_stdin() {
        let client =
            CommandLlmClient::new(r#"bytes=$(wc -c); echo "{\"content\":{\"bytes\":$bytes}}""#);
        let result = client
            .structured_output("system prompt", "user prompt")
            .await
            .unwrap();
        let bytes = result["bytes"].as_u64().unwrap_or(0);
        assert!(bytes > 30, "expected stdin bytes > 30, got {bytes}");
    }
}