1use std::path::{Path, PathBuf};
12
13use crate::error::OxideError;
14
/// Produces text embeddings by delegating to an external Mojo executable.
///
/// The embedder only stores the path to the executable; the work happens in
/// [`MojoEmbedder::embed_batch`], which requires the `mojo-interop` cargo
/// feature.
#[derive(Debug, Clone)]
pub struct MojoEmbedder {
    /// Path of the executable spawned for each batch.
    // Only read by the `mojo-interop` subprocess path; without that feature the
    // field is intentionally unused, so silence the lint only in that build.
    #[cfg_attr(not(feature = "mojo-interop"), allow(dead_code))]
    binary_path: PathBuf,
}
27
28impl MojoEmbedder {
29 pub fn new(binary_path: impl AsRef<Path>) -> Self {
31 Self {
32 binary_path: binary_path.as_ref().to_path_buf(),
33 }
34 }
35
36 pub async fn embed_batch(
45 &self,
46 texts: Vec<String>,
47 ) -> Result<Vec<Vec<f32>>, OxideError> {
48 #[cfg(feature = "mojo-interop")]
49 {
50 self.run_subprocess(texts).await
51 }
52 #[cfg(not(feature = "mojo-interop"))]
53 {
54 let _ = texts;
55 Err(OxideError::Other(
56 "MojoEmbedder requires the `mojo-interop` feature: \
57 oxide-agent = { features = [\"mojo-interop\"] }"
58 .into(),
59 ))
60 }
61 }
62
63 #[cfg(feature = "mojo-interop")]
64 async fn run_subprocess(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, OxideError> {
65 use tokio::io::{AsyncReadExt, AsyncWriteExt};
66 use tokio::process::Command;
67
68 let request = serde_json::json!({"texts": texts});
69 let request_line = serde_json::to_string(&request).map_err(OxideError::Serde)?;
70
71 let mut child = Command::new(&self.binary_path)
72 .stdin(std::process::Stdio::piped())
73 .stdout(std::process::Stdio::piped())
74 .stderr(std::process::Stdio::piped())
75 .spawn()
76 .map_err(|e| OxideError::Other(format!("spawn mojo binary: {e}")))?;
77
78 if let Some(mut stdin) = child.stdin.take() {
80 stdin
81 .write_all(request_line.as_bytes())
82 .await
83 .map_err(|e| OxideError::Other(format!("write to mojo stdin: {e}")))?;
84 stdin
85 .write_all(b"\n")
86 .await
87 .map_err(|e| OxideError::Other(format!("write newline: {e}")))?;
88 }
89
90 let mut stdout_buf = String::new();
92 if let Some(mut stdout) = child.stdout.take() {
93 stdout
94 .read_to_string(&mut stdout_buf)
95 .await
96 .map_err(|e| OxideError::Other(format!("read mojo stdout: {e}")))?;
97 }
98
99 let status = child
100 .wait()
101 .await
102 .map_err(|e| OxideError::Other(format!("wait for mojo: {e}")))?;
103
104 if !status.success() {
105 return Err(OxideError::Other(format!(
106 "mojo binary exited with status {status}"
107 )));
108 }
109
110 let line = stdout_buf
112 .lines()
113 .find(|l| !l.trim().is_empty())
114 .ok_or_else(|| OxideError::Other("mojo binary produced no output".into()))?;
115
116 #[derive(serde::Deserialize)]
117 struct MojoResponse {
118 embeddings: Vec<Vec<f32>>,
119 }
120
121 let resp: MojoResponse =
122 serde_json::from_str(line).map_err(OxideError::Serde)?;
123
124 Ok(resp.embeddings)
125 }
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    /// Without the `mojo-interop` feature the stub must fail and point the
    /// caller at the feature flag they need to enable.
    #[cfg(not(feature = "mojo-interop"))]
    #[tokio::test]
    async fn stub_returns_meaningful_error_without_feature() {
        let embedder = MojoEmbedder::new("/usr/local/bin/mojo-embed");
        let err = embedder
            .embed_batch(vec!["hello".into()])
            .await
            .expect_err("stub must return an error");
        let msg = err.to_string();
        assert!(
            msg.contains("mojo-interop"),
            "error should mention the feature flag, got: {msg}"
        );
    }

    /// With the feature enabled, a binary that does not exist must surface as
    /// a spawn error rather than a panic or a hang.
    #[cfg(feature = "mojo-interop")]
    #[tokio::test]
    async fn missing_binary_reports_spawn_error() {
        let embedder = MojoEmbedder::new("/nonexistent/path/to/mojo-embed");
        let err = embedder
            .embed_batch(vec!["hello".into()])
            .await
            .expect_err("spawning a nonexistent binary must fail");
        let msg = err.to_string();
        assert!(
            msg.contains("spawn mojo binary"),
            "error should come from spawning, got: {msg}"
        );
    }
}