use crate::error::{Error, Result};
use std::io::{BufRead, Write};
use std::process::{Child, ChildStdin, ChildStdout, Command, Stdio};
/// Converts a piece of text into a vector of `f32` values.
///
/// Implementors must be `Send + Sync`, so a single embedder can be shared
/// between threads.
pub trait Embedder: Send + Sync {
    /// Produces the embedding for `text`.
    ///
    /// # Errors
    ///
    /// Returns an error when the implementation cannot produce an embedding.
    fn embed(&self, text: &str) -> Result<Vec<f32>>;
}
/// An [`Embedder`] that delegates to a child process.
///
/// The process handles live behind a mutex so that `embed` can be called
/// through a shared reference while only one caller at a time talks to the
/// child.
pub struct SubprocessEmbedder {
    /// Mutex-guarded handles to the child process and its pipes.
    child: std::sync::Mutex<SubprocessState>,
}
/// Handles to a running child process: the process itself, the write end of
/// its stdin, and a buffered reader over its stdout.
///
/// Dropping the state kills the process and waits for it, so the child is
/// never left running or lingering as a zombie.
struct SubprocessState {
    /// The process handle; used on drop to terminate and reap the child.
    _child: Child,
    /// Pipe used to send data to the child.
    stdin: ChildStdin,
    /// Line-oriented reader over the data the child writes back.
    stdout: std::io::BufReader<ChildStdout>,
}

impl Drop for SubprocessState {
    /// Terminates the child process and collects its exit status.
    fn drop(&mut self) {
        // `std::process::Child` has no `Drop` behaviour of its own: without an
        // explicit `wait` the OS keeps the process entry around after exit.
        // Both calls are best effort -- `kill` fails harmlessly when the child
        // has already exited, and there is nothing useful to do with a `wait`
        // error while dropping.
        let _ = self._child.kill();
        let _ = self._child.wait();
    }
}
impl SubprocessEmbedder {
pub fn new(program: &str, args: &[&str]) -> Result<Self> {
let mut child = Command::new(program)
.args(args)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::inherit())
.spawn()
.map_err(|e| Error::InvalidInput(format!("failed to spawn embedder: {e}")))?;
let stdin = child
.stdin
.take()
.ok_or_else(|| Error::InvalidInput("no stdin handle".into()))?;
let stdout = child
.stdout
.take()
.ok_or_else(|| Error::InvalidInput("no stdout handle".into()))?;
Ok(Self {
child: std::sync::Mutex::new(SubprocessState {
_child: child,
stdin,
stdout: std::io::BufReader::new(stdout),
}),
})
}
}
impl Embedder for SubprocessEmbedder {
fn embed(&self, text: &str) -> Result<Vec<f32>> {
let mut state = self
.child
.lock()
.map_err(|_| Error::InvalidInput("embedder mutex poisoned".into()))?;
let sanitised = text.replace('\n', " ");
writeln!(state.stdin, "{sanitised}")
.map_err(|e| Error::InvalidInput(format!("write to embedder: {e}")))?;
let mut line = String::new();
state
.stdout
.read_line(&mut line)
.map_err(|e| Error::InvalidInput(format!("read from embedder: {e}")))?;
line.split_whitespace()
.map(|s| {
s.parse::<f32>()
.map_err(|e| Error::InvalidInput(format!("bad float from embedder: {e}")))
})
.collect()
}
}
/// An embedder wrapping a fixed vector of `f32` values.
///
/// The wrapped vector is public, so it can be constructed directly with
/// `FixedEmbedder(values)`.
#[derive(Debug, Clone, PartialEq)]
pub struct FixedEmbedder(pub Vec<f32>);
impl Embedder for FixedEmbedder {
fn embed(&self, _text: &str) -> Result<Vec<f32>> {
Ok(self.0.clone())
}
}