Skip to main content

oxide_agent/mojo/
mod.rs

1// Mojo Interoperability — offload batch embedding to a Mojo-compiled binary.
2//
3// Protocol (line-delimited JSON over stdin/stdout):
4//
5//   stdin  → {"texts": ["hello", "world"]}
6//   stdout ← {"embeddings": [[0.1, 0.2, ...], [0.3, 0.4, ...]]}
7//
8// Enable with the `mojo-interop` feature flag:
9//   oxide-agent = { features = ["mojo-interop"] }
10
11use std::path::{Path, PathBuf};
12
13use crate::error::OxideError;
14
15// ── MojoEmbedder ─────────────────────────────────────────────────────────────
16
/// Bridge to a Mojo-compiled embedding binary for high-throughput batch
/// embedding.
///
/// The binary must implement the line-delimited JSON protocol described at
/// the top of this file.
///
/// Constructing the value with [`MojoEmbedder::new`] only records the path to
/// the binary; nothing is spawned at that point.  With the `mojo-interop`
/// feature enabled, every [`MojoEmbedder::embed_batch`] call launches one
/// subprocess, so prefer a few large batches over many small ones.  The
/// embedder itself is cheap to clone and can be reused across calls.
#[derive(Debug, Clone)]
pub struct MojoEmbedder {
    /// Filesystem path of the Mojo executable to launch.
    #[allow(dead_code)] // used only under the `mojo-interop` feature
    binary_path: PathBuf,
}
27
28impl MojoEmbedder {
29    /// Create a new embedder pointing at the given Mojo binary.
30    pub fn new(binary_path: impl AsRef<Path>) -> Self {
31        Self {
32            binary_path: binary_path.as_ref().to_path_buf(),
33        }
34    }
35
36    /// Embed a batch of strings.
37    ///
38    /// Spawns the Mojo binary, feeds it a single-line JSON request on stdin,
39    /// and parses the single-line JSON response from stdout.
40    ///
41    /// # Errors
42    /// Returns an error if the binary cannot be spawned, exits non-zero, or
43    /// returns malformed JSON.
44    pub async fn embed_batch(
45        &self,
46        texts: Vec<String>,
47    ) -> Result<Vec<Vec<f32>>, OxideError> {
48        #[cfg(feature = "mojo-interop")]
49        {
50            self.run_subprocess(texts).await
51        }
52        #[cfg(not(feature = "mojo-interop"))]
53        {
54            let _ = texts;
55            Err(OxideError::Other(
56                "MojoEmbedder requires the `mojo-interop` feature: \
57                 oxide-agent = { features = [\"mojo-interop\"] }"
58                    .into(),
59            ))
60        }
61    }
62
63    #[cfg(feature = "mojo-interop")]
64    async fn run_subprocess(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, OxideError> {
65        use tokio::io::{AsyncReadExt, AsyncWriteExt};
66        use tokio::process::Command;
67
68        let request = serde_json::json!({"texts": texts});
69        let request_line = serde_json::to_string(&request).map_err(OxideError::Serde)?;
70
71        let mut child = Command::new(&self.binary_path)
72            .stdin(std::process::Stdio::piped())
73            .stdout(std::process::Stdio::piped())
74            .stderr(std::process::Stdio::piped())
75            .spawn()
76            .map_err(|e| OxideError::Other(format!("spawn mojo binary: {e}")))?;
77
78        // Write request to stdin and close it.
79        if let Some(mut stdin) = child.stdin.take() {
80            stdin
81                .write_all(request_line.as_bytes())
82                .await
83                .map_err(|e| OxideError::Other(format!("write to mojo stdin: {e}")))?;
84            stdin
85                .write_all(b"\n")
86                .await
87                .map_err(|e| OxideError::Other(format!("write newline: {e}")))?;
88        }
89
90        // Read stdout.
91        let mut stdout_buf = String::new();
92        if let Some(mut stdout) = child.stdout.take() {
93            stdout
94                .read_to_string(&mut stdout_buf)
95                .await
96                .map_err(|e| OxideError::Other(format!("read mojo stdout: {e}")))?;
97        }
98
99        let status = child
100            .wait()
101            .await
102            .map_err(|e| OxideError::Other(format!("wait for mojo: {e}")))?;
103
104        if !status.success() {
105            return Err(OxideError::Other(format!(
106                "mojo binary exited with status {status}"
107            )));
108        }
109
110        // Parse the first non-empty line as JSON.
111        let line = stdout_buf
112            .lines()
113            .find(|l| !l.trim().is_empty())
114            .ok_or_else(|| OxideError::Other("mojo binary produced no output".into()))?;
115
116        #[derive(serde::Deserialize)]
117        struct MojoResponse {
118            embeddings: Vec<Vec<f32>>,
119        }
120
121        let resp: MojoResponse =
122            serde_json::from_str(line).map_err(OxideError::Serde)?;
123
124        Ok(resp.embeddings)
125    }
126}
127
128// ── Tests ─────────────────────────────────────────────────────────────────────
129
#[cfg(test)]
mod tests {
    use super::*;

    /// Without the `mojo-interop` feature, `embed_batch` must fail with a
    /// message that tells the user which feature to enable.
    ///
    /// The `cfg` sits on the function itself so that a feature-enabled build
    /// does not report an empty test as passing.
    #[cfg(not(feature = "mojo-interop"))]
    #[tokio::test]
    async fn stub_returns_meaningful_error_without_feature() {
        let embedder = MojoEmbedder::new("/usr/local/bin/mojo-embed");
        let err = embedder
            .embed_batch(vec!["hello".into()])
            .await
            .unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("mojo-interop"),
            "error should mention the feature flag"
        );
    }

    /// With the feature enabled, a path that does not exist must surface as
    /// a spawn error rather than a panic or a hang.
    #[cfg(feature = "mojo-interop")]
    #[tokio::test]
    async fn missing_binary_reports_spawn_error() {
        let embedder = MojoEmbedder::new("/nonexistent/path/to/mojo-embed");
        let err = embedder
            .embed_batch(vec!["hello".into()])
            .await
            .unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("spawn mojo binary"),
            "error should identify the failed spawn, got: {msg}"
        );
    }
}