use std::path::{Path, PathBuf};
use crate::error::OxideError;
/// Computes text embeddings by delegating to an external Mojo executable.
///
/// The executable is only invoked when the crate is built with the
/// `mojo-interop` feature; without it, `embed_batch` returns an error that
/// explains how to enable the feature.
#[derive(Debug, Clone)]
pub struct MojoEmbedder {
    /// Path to the embedding executable. It is only read by the
    /// `mojo-interop` code path, so the `dead_code` lint is silenced only when
    /// that feature is off (keeping real dead-code warnings when it is on).
    #[cfg_attr(not(feature = "mojo-interop"), allow(dead_code))]
    binary_path: PathBuf,
}
impl MojoEmbedder {
    /// Creates an embedder that will invoke the executable at `binary_path`.
    ///
    /// The path is only stored here; nothing checks that it exists until
    /// `embed_batch` tries to spawn it.
    pub fn new(binary_path: impl AsRef<Path>) -> Self {
        Self {
            binary_path: binary_path.as_ref().to_path_buf(),
        }
    }

    /// Computes one embedding vector per entry in `texts`.
    ///
    /// With the `mojo-interop` feature enabled, this spawns the configured
    /// binary, writes `{"texts": [...]}` as a single JSON line to its stdin,
    /// and parses the first non-empty stdout line as `{"embeddings": [[...]]}`.
    ///
    /// # Errors
    ///
    /// Without the `mojo-interop` feature this always returns
    /// `OxideError::Other` naming the missing feature. With the feature, it
    /// returns an error in these cases:
    /// - the request cannot be serialized;
    /// - the binary cannot be spawned;
    /// - I/O with the binary fails;
    /// - the binary exits unsuccessfully (stderr is included in the message);
    /// - its output is empty or not a valid response;
    /// - the number of returned embeddings differs from `texts.len()`.
    pub async fn embed_batch(
        &self,
        texts: Vec<String>,
    ) -> Result<Vec<Vec<f32>>, OxideError> {
        #[cfg(feature = "mojo-interop")]
        {
            self.run_subprocess(texts).await
        }
        #[cfg(not(feature = "mojo-interop"))]
        {
            let _ = texts;
            Err(OxideError::Other(
                "MojoEmbedder requires the `mojo-interop` feature: \
                 oxide-agent = { features = [\"mojo-interop\"] }"
                    .into(),
            ))
        }
    }

    /// Runs one request/response round trip with the embedding binary.
    #[cfg(feature = "mojo-interop")]
    async fn run_subprocess(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, OxideError> {
        use std::process::Stdio;
        use tokio::io::AsyncWriteExt;
        use tokio::process::Command;

        /// Shape of the JSON line the binary is expected to print on stdout.
        #[derive(serde::Deserialize)]
        struct MojoResponse {
            embeddings: Vec<Vec<f32>>,
        }

        // Remember the batch size before `texts` is moved into the request.
        let expected = texts.len();
        let request = serde_json::json!({"texts": texts});
        let request_line = serde_json::to_string(&request).map_err(OxideError::Serde)?;

        let mut child = Command::new(&self.binary_path)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            // If this future is dropped (e.g. cancelled by a timeout), kill the
            // child instead of leaving an orphaned process behind.
            .kill_on_drop(true)
            .spawn()
            .map_err(|e| OxideError::Other(format!("spawn mojo binary: {e}")))?;

        let mut stdin = child
            .stdin
            .take()
            .ok_or_else(|| OxideError::Other("mojo binary stdin was not captured".into()))?;
        // NOTE(review): the request is written in full before any output is
        // read, which assumes the binary consumes its whole input line before
        // emitting large amounts of output.
        stdin
            .write_all(request_line.as_bytes())
            .await
            .map_err(|e| OxideError::Other(format!("write to mojo stdin: {e}")))?;
        stdin
            .write_all(b"\n")
            .await
            .map_err(|e| OxideError::Other(format!("write newline: {e}")))?;
        // Close stdin so the child sees EOF after the request line.
        drop(stdin);

        // `wait_with_output` drains stdout and stderr concurrently while waiting
        // for exit. Reading only stdout would let a child that writes more than
        // a pipe buffer of stderr block forever, deadlocking both sides.
        let output = child
            .wait_with_output()
            .await
            .map_err(|e| OxideError::Other(format!("wait for mojo: {e}")))?;

        if !output.status.success() {
            let stderr = String::from_utf8_lossy(&output.stderr);
            let stderr = stderr.trim();
            let msg = if stderr.is_empty() {
                format!("mojo binary exited with status {}", output.status)
            } else {
                format!("mojo binary exited with status {}: {stderr}", output.status)
            };
            return Err(OxideError::Other(msg));
        }

        let stdout = std::str::from_utf8(&output.stdout)
            .map_err(|e| OxideError::Other(format!("read mojo stdout: {e}")))?;
        let line = stdout
            .lines()
            .find(|l| !l.trim().is_empty())
            .ok_or_else(|| OxideError::Other("mojo binary produced no output".into()))?;

        let resp: MojoResponse = serde_json::from_str(line).map_err(OxideError::Serde)?;
        // Callers pair embeddings with inputs by index, so a short or long
        // response would silently misalign results.
        if resp.embeddings.len() != expected {
            return Err(OxideError::Other(format!(
                "mojo binary returned {} embeddings for {expected} texts",
                resp.embeddings.len()
            )));
        }
        Ok(resp.embeddings)
    }
}
#[cfg(test)]
mod tests {
    /// Without the `mojo-interop` feature, `embed_batch` must fail with an
    /// error telling the caller which feature to enable.
    ///
    /// The test only exists when the feature is off. Compiling it with the
    /// feature on would make it an empty body that always reports success.
    #[cfg(not(feature = "mojo-interop"))]
    #[tokio::test]
    async fn stub_returns_meaningful_error_without_feature() {
        use super::*;

        let embedder = MojoEmbedder::new("/usr/local/bin/mojo-embed");
        let err = embedder
            .embed_batch(vec!["hello".into()])
            .await
            .unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("mojo-interop"),
            "error should mention the feature flag, got: {msg}"
        );
    }
}