use std::num::NonZeroUsize;
use std::sync::Arc;
use anyhow::{Context, Result};
use async_trait::async_trait;
use fastembed::{EmbeddingModel, TextEmbedding, TextInitOptions};
use lru::LruCache;
use parking_lot::Mutex;
/// Embedding dimension that `FastEmbedder` records as its `dim` at construction.
pub const EMBED_DIM: usize = 384;
/// LRU cache capacity (number of cached texts) used by `FastEmbedder::new`.
pub const DEFAULT_CACHE_CAPACITY: usize = 256;
/// Tag describing which ONNX Runtime execution provider an embedder registered.
///
/// The tag records what was *registered* during initialisation; it is used for
/// logging and for deciding whether a CPU retry makes sense.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExecutionProvider {
    /// CPU execution provider only.
    Cpu,
    /// CoreML with the `all` or `cpu_gpu` compute-unit selection.
    CoreML,
    /// CoreML with the default CPU + Neural Engine selection (also used as the
    /// tag for the `cpu_only` selection).
    CoreMLAne,
    /// CUDA execution provider (with a CPU provider registered behind it).
    Cuda,
}

impl ExecutionProvider {
    /// Returns the short, human-readable label for this provider.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Cpu => "CPU",
            Self::CoreML => "CoreML",
            Self::CoreMLAne => "CoreML(ANE)",
            Self::Cuda => "CUDA",
        }
    }
}

impl std::fmt::Display for ExecutionProvider {
    /// Writes the same label as [`ExecutionProvider::as_str`], without applying
    /// any width/padding flags from the formatter.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = self.as_str();
        out.write_str(label)
    }
}
/// Abstraction over a text-embedding backend.
///
/// The trait is `Send + Sync` so implementations can be shared across tasks.
#[async_trait]
pub trait Embedder: Send + Sync {
/// Embeds every string in `texts`.
///
/// The implementations in this file return one vector per input, in input
/// order, and an empty vector for an empty slice.
async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
/// Dimension the embedder reports for its vectors.
fn dimension(&self) -> usize;
/// Execution provider tag for this embedder; `Cpu` unless an implementation
/// overrides it.
fn provider(&self) -> ExecutionProvider {
ExecutionProvider::Cpu
}
}
/// Embeds a single piece of text by sending a one-element batch to `embedder`.
///
/// # Errors
///
/// Propagates any error from [`Embedder::embed_batch`], and fails if the
/// embedder hands back an empty result. If the embedder returns more than one
/// vector, the last one is used.
pub async fn embed_one(embedder: &dyn Embedder, text: &str) -> Result<Vec<f32>> {
    let batch = [text.to_string()];
    let embeddings = embedder.embed_batch(&batch).await?;
    embeddings
        .into_iter()
        .next_back()
        .context("embedder returned no embedding for non-empty input")
}
/// Embedder backed by a fastembed `TextEmbedding` model, with an LRU cache of
/// previously computed embeddings keyed by the input text.
pub struct FastEmbedder {
// The model sits behind `Arc<Mutex<..>>` so it can be moved into blocking
// tasks and locked there for the duration of an `embed` call.
model: Arc<Mutex<TextEmbedding>>,
// Input text -> embedding. Consulted before running the model and updated
// with every freshly computed vector.
cache: Arc<Mutex<LruCache<String, Vec<f32>>>>,
// Value reported by `dimension()`; set to `EMBED_DIM` at construction.
dim: usize,
// Execution provider tag chosen while the model was initialised.
provider: ExecutionProvider,
}
impl FastEmbedder {
    /// Builds an embedder with the default cache capacity
    /// ([`DEFAULT_CACHE_CAPACITY`]).
    ///
    /// # Errors
    ///
    /// See [`FastEmbedder::with_cache_size`].
    pub async fn new() -> Result<Self> {
        Self::with_cache_size(DEFAULT_CACHE_CAPACITY).await
    }

    /// Execution provider tag chosen while the model was initialised.
    pub fn provider(&self) -> ExecutionProvider {
        self.provider
    }

    /// Returns `true` when the `TRUSTY_DEVICE` environment variable is set and
    /// equals `expected`, compared ASCII case-insensitively.
    ///
    /// An unset or non-Unicode value counts as "no override".
    fn device_override_is(expected: &str) -> bool {
        std::env::var("TRUSTY_DEVICE")
            .map(|v| v.eq_ignore_ascii_case(expected))
            .unwrap_or(false)
    }

    /// Builds init options that register only the CPU (no-arena) execution
    /// provider, regardless of build features or environment variables.
    ///
    /// Used for the CPU retry after an accelerated provider fails, so that the
    /// retry does not have to mutate the process environment.
    fn cpu_only_options(model: EmbeddingModel) -> (TextInitOptions, ExecutionProvider) {
        use ort::execution_providers::ExecutionProviderDispatch;
        let cpu_no_arena: ExecutionProviderDispatch =
            ort::ep::CPU::default().with_arena_allocator(false).build();
        tracing::info!("trusty-embedder: registering CPU(no-arena) execution provider");
        (
            TextInitOptions::new(model).with_execution_providers(vec![cpu_no_arena]),
            ExecutionProvider::Cpu,
        )
    }

    /// Chooses execution providers for `model` from build features, target
    /// platform and environment variables.
    ///
    /// Selection order:
    /// 1. With the `embedder-cuda` feature: CUDA + CPU, unless
    ///    `TRUSTY_DEVICE=cpu`.
    /// 2. On aarch64 macOS: CoreML + CPU, unless `TRUSTY_DEVICE=cpu`. The
    ///    compute units come from `TRUSTY_COREML_COMPUTE_UNITS` (`all`,
    ///    `cpu_gpu`/`cpuandgpu`, `cpu_only`/`cpuonly`); anything else selects
    ///    CPU + Neural Engine.
    /// 3. Otherwise: CPU only.
    ///
    /// The returned tag reflects what was *registered*. As the CUDA log message
    /// notes, the runtime may still fall back to CPU at session init.
    fn init_options(model: EmbeddingModel) -> (TextInitOptions, ExecutionProvider) {
        use ort::execution_providers::ExecutionProviderDispatch;
        let opts = TextInitOptions::new(model);
        let cpu_no_arena: ExecutionProviderDispatch =
            ort::ep::CPU::default().with_arena_allocator(false).build();
        #[cfg(feature = "embedder-cuda")]
        {
            if !Self::device_override_is("cpu") {
                let cuda: ExecutionProviderDispatch = ort::ep::CUDA::default().build();
                let providers: Vec<ExecutionProviderDispatch> = vec![cuda, cpu_no_arena];
                tracing::info!(
                    "trusty-embedder: registering CUDA + CPU(no-arena) execution providers \
                     (will fall back to CPU at session-init if no CUDA device is available)"
                );
                return (
                    opts.with_execution_providers(providers),
                    ExecutionProvider::Cuda,
                );
            }
            tracing::info!(
                "trusty-embedder: TRUSTY_DEVICE=cpu set — skipping CUDA EP registration"
            );
        }
        #[cfg(all(target_arch = "aarch64", target_os = "macos"))]
        {
            if !Self::device_override_is("cpu") {
                use ort::ep::coreml::{ComputeUnits, SpecializationStrategy};
                // NOTE(review): `cpu_only` is tagged `CoreMLAne` although its
                // compute units exclude the Neural Engine — confirm intended.
                let (units, units_tag) = match std::env::var("TRUSTY_COREML_COMPUTE_UNITS")
                    .ok()
                    .as_deref()
                    .map(|s| s.trim().to_ascii_lowercase())
                    .as_deref()
                {
                    Some("all") => (ComputeUnits::All, ExecutionProvider::CoreML),
                    Some("cpu_gpu") | Some("cpuandgpu") => {
                        (ComputeUnits::CPUAndGPU, ExecutionProvider::CoreML)
                    }
                    Some("cpu_only") | Some("cpuonly") => {
                        (ComputeUnits::CPUOnly, ExecutionProvider::CoreMLAne)
                    }
                    _ => (
                        ComputeUnits::CPUAndNeuralEngine,
                        ExecutionProvider::CoreMLAne,
                    ),
                };
                let cache_dir = std::env::var("HOME")
                    .map(|h| format!("{}/Library/Caches/trusty-embedder/coreml", h))
                    .unwrap_or_else(|_| "/tmp/trusty-embedder-coreml".to_string());
                // Best effort: a missing cache directory must not prevent
                // start-up, but the failure should be visible in the logs.
                if let Err(err) = std::fs::create_dir_all(&cache_dir) {
                    tracing::warn!(
                        "trusty-embedder: could not create CoreML model cache dir {}: {}",
                        cache_dir,
                        err
                    );
                }
                let coreml: ExecutionProviderDispatch = ort::ep::CoreML::default()
                    .with_compute_units(units)
                    .with_static_input_shapes(true)
                    .with_specialization_strategy(SpecializationStrategy::FastPrediction)
                    .with_model_cache_dir(cache_dir.clone())
                    .build();
                let providers: Vec<ExecutionProviderDispatch> = vec![coreml, cpu_no_arena];
                let units_str = match units {
                    ComputeUnits::All => "all",
                    ComputeUnits::CPUAndGPU => "cpu_gpu",
                    ComputeUnits::CPUOnly => "cpu_only",
                    ComputeUnits::CPUAndNeuralEngine => "cpu_ane",
                };
                tracing::info!(
                    "trusty-embedder: registering CoreML (compute_units={}, static_shapes=true, \
                     cache={}) + CPU(no-arena) execution providers (Apple Silicon)",
                    units_str,
                    cache_dir,
                );
                return (opts.with_execution_providers(providers), units_tag);
            }
            tracing::info!(
                "trusty-embedder: TRUSTY_DEVICE=cpu set — skipping CoreML EP registration (Apple Silicon)"
            );
        }
        // Kept from the original: presumably silences lint noise for some cfg
        // combinations of the accelerated branches above.
        #[allow(unreachable_code)]
        {
            tracing::info!("trusty-embedder: registering CPU(no-arena) execution provider");
            let providers: Vec<ExecutionProviderDispatch> = vec![cpu_no_arena];
            (
                opts.with_execution_providers(providers),
                ExecutionProvider::Cpu,
            )
        }
    }

    /// Loads the embedding model, walking the fallback chain, then runs a small
    /// warm-up batch. Blocking; intended to run on the blocking thread pool.
    ///
    /// Fallback chain when the quantised model (`AllMiniLML6V2Q`) fails to load:
    /// * `TRUSTY_DEVICE=gpu`: fail immediately, no fallback.
    /// * An accelerated provider was registered: retry the quantised model on
    ///   CPU only, then the non-quantised `AllMiniLML6V2` on CPU only.
    /// * Already CPU: try the non-quantised `AllMiniLML6V2`.
    ///
    /// The CPU retry builds CPU-only options directly and leaves the process
    /// environment untouched. Mutating `TRUSTY_DEVICE` here would be a
    /// process-wide side effect and is unsound while other threads may read the
    /// environment.
    fn load_and_warm_model() -> Result<(TextEmbedding, ExecutionProvider)> {
        let require_gpu = Self::device_override_is("gpu");
        let (q_opts, q_provider) = Self::init_options(EmbeddingModel::AllMiniLML6V2Q);
        let (mut model, provider) = match TextEmbedding::try_new(q_opts) {
            Ok(m) => (m, q_provider),
            Err(q_err) if require_gpu => {
                anyhow::bail!(
                    "TRUSTY_DEVICE=gpu requested but accelerated execution provider \
                     failed to initialise: {q_err:#}"
                );
            }
            Err(q_err) if q_provider != ExecutionProvider::Cpu => {
                tracing::warn!(
                    "{} EP init failed ({q_err:#}); retrying with CPU-only \
                     execution provider",
                    q_provider
                );
                let (cpu_opts, cpu_provider) =
                    Self::cpu_only_options(EmbeddingModel::AllMiniLML6V2Q);
                match TextEmbedding::try_new(cpu_opts) {
                    Ok(m) => (m, cpu_provider),
                    Err(cpu_err) => {
                        tracing::warn!(
                            "AllMiniLML6V2Q init failed on CPU ({cpu_err:#}), \
                             falling back to AllMiniLML6V2"
                        );
                        let (fb_opts, fb_provider) =
                            Self::cpu_only_options(EmbeddingModel::AllMiniLML6V2);
                        let m = TextEmbedding::try_new(fb_opts).with_context(|| {
                            format!(
                                "failed to initialise fastembed (tried {q_provider}→CPU on \
                                 AllMiniLML6V2Q, then AllMiniLML6V2)"
                            )
                        })?;
                        (m, fb_provider)
                    }
                }
            }
            Err(q_err) => {
                tracing::warn!(
                    "AllMiniLML6V2Q init failed ({q_err:#}), falling back to AllMiniLML6V2"
                );
                let (fb_opts, fb_provider) = Self::init_options(EmbeddingModel::AllMiniLML6V2);
                let m = TextEmbedding::try_new(fb_opts).context(
                    "failed to initialise fastembed (tried AllMiniLML6V2Q and AllMiniLML6V2)",
                )?;
                (m, fb_provider)
            }
        };
        // Run one small batch up front; a failure here surfaces at construction
        // time instead of on the first real request.
        let warmup: Vec<&str> = vec![
            "hello world",
            "the quick brown fox",
            "memory palace warmup",
            "embedding model ready",
            "trusty common warmup",
        ];
        let _ = model
            .embed(warmup, None)
            .context("fastembed warmup batch failed")?;
        Ok((model, provider))
    }

    /// Builds an embedder whose LRU cache holds up to `capacity` texts.
    ///
    /// A `capacity` of `0` is clamped to `1`. Model loading and warm-up run on
    /// the blocking thread pool.
    ///
    /// # Errors
    ///
    /// Fails if the blocking task cannot be joined, if every model/provider in
    /// the fallback chain fails to initialise, if `TRUSTY_DEVICE=gpu` is set and
    /// the first initialisation attempt fails, or if the warm-up batch fails.
    pub async fn with_cache_size(capacity: usize) -> Result<Self> {
        let capacity =
            NonZeroUsize::new(capacity.max(1)).expect("capacity.max(1) is always non-zero");
        let (model, provider) = tokio::task::spawn_blocking(Self::load_and_warm_model)
            .await
            .context("spawn_blocking joined with error during embedder init")??;
        tracing::info!(
            "trusty-embedder: FastEmbedder ready (provider={}, dim={})",
            provider,
            EMBED_DIM
        );
        Ok(Self {
            model: Arc::new(Mutex::new(model)),
            cache: Arc::new(Mutex::new(LruCache::new(capacity))),
            dim: EMBED_DIM,
            provider,
        })
    }
}
#[async_trait]
impl Embedder for FastEmbedder {
    /// Embeds `texts`, serving cached inputs from the LRU cache and running the
    /// model (on the blocking thread pool) only for the cache misses.
    ///
    /// Results are returned in input order. Freshly computed vectors are
    /// inserted into the cache.
    ///
    /// # Errors
    ///
    /// Fails if the blocking task cannot be joined, if the model call fails, or
    /// if the model returns a different number of vectors than requested.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        // One slot per input; `None` marks a cache miss still to be computed.
        let mut slots: Vec<Option<Vec<f32>>> = Vec::with_capacity(texts.len());
        let mut misses: Vec<(usize, String)> = Vec::new();
        {
            // Scoped so the guard is released before the `.await` below.
            let mut cache = self.cache.lock();
            for (position, text) in texts.iter().enumerate() {
                match cache.get(text) {
                    Some(hit) => slots.push(Some(hit.clone())),
                    None => {
                        slots.push(None);
                        misses.push((position, text.clone()));
                    }
                }
            }
        }

        if !misses.is_empty() {
            let inputs: Vec<String> = misses.iter().map(|(_, text)| text.clone()).collect();
            let model = Arc::clone(&self.model);
            let fresh = tokio::task::spawn_blocking(move || -> Result<Vec<Vec<f32>>> {
                let mut locked = model.lock();
                locked
                    .embed(inputs, None)
                    .context("fastembed embed call failed")
            })
            .await
            .context("spawn_blocking joined with error during embed")??;

            anyhow::ensure!(
                fresh.len() == misses.len(),
                "fastembed returned {} embeddings, expected {}",
                fresh.len(),
                misses.len()
            );

            let mut cache = self.cache.lock();
            for ((position, text), embedding) in misses.into_iter().zip(fresh) {
                cache.put(text, embedding.clone());
                slots[position] = Some(embedding);
            }
        }

        slots
            .into_iter()
            .map(|slot| slot.context("missing embedding slot after batch"))
            .collect()
    }

    /// Dimension recorded at construction.
    fn dimension(&self) -> usize {
        self.dim
    }

    /// Execution provider tag chosen while the model was initialised.
    fn provider(&self) -> ExecutionProvider {
        self.provider
    }
}
/// Deterministic, model-free embedder for tests.
///
/// Vectors are derived from the bytes of the input text, so equal texts always
/// map to equal vectors.
#[cfg(any(test, feature = "embedder-test-support"))]
pub struct MockEmbedder {
    // Length of every vector produced by this mock.
    dim: usize,
}

#[cfg(any(test, feature = "embedder-test-support"))]
impl MockEmbedder {
    /// Creates a mock that produces vectors of length `dim`.
    ///
    /// A `dim` of `0` is allowed and yields empty vectors.
    pub fn new(dim: usize) -> Self {
        Self { dim }
    }

    /// Maps `text` to a deterministic vector of length `self.dim`.
    ///
    /// Each byte `b` at index `i` adds `b / 255` to slot `(i + b) % dim`; the
    /// first byte is additionally added to slot 0. Empty text yields all zeros.
    fn hash_to_vec(&self, text: &str) -> Vec<f32> {
        let mut v = vec![0.0_f32; self.dim];
        // With no slots there is nothing to fill; bail out before the modulo
        // and the `v[0]` access below, both of which would panic for dim == 0.
        if self.dim == 0 {
            return v;
        }
        for (i, b) in text.bytes().enumerate() {
            let slot = (i + b as usize) % self.dim;
            v[slot] += (b as f32) / 255.0;
        }
        if let Some(first) = text.bytes().next() {
            v[0] += first as f32 / 255.0;
        }
        v
    }
}
#[cfg(any(test, feature = "embedder-test-support"))]
#[async_trait]
impl Embedder for MockEmbedder {
    /// Returns one deterministic vector per input text, in input order.
    /// Never fails.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let mut vectors = Vec::with_capacity(texts.len());
        for text in texts {
            vectors.push(self.hash_to_vec(text));
        }
        Ok(vectors)
    }

    /// Vector length this mock was constructed with.
    fn dimension(&self) -> usize {
        self.dim
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// One lock shared by every test in this module that mutates process
    /// environment variables. It must live at module level: a `static` declared
    /// inside each test function would be a separate lock per test and would
    /// not serialize the tests against each other.
    #[cfg(all(target_arch = "aarch64", target_os = "macos"))]
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// RAII guard that holds [`ENV_LOCK`], applies a set of environment
    /// overrides, and restores the previous values when dropped — including
    /// when the test panics on a failed assertion.
    #[cfg(all(target_arch = "aarch64", target_os = "macos"))]
    struct EnvOverride {
        // Previous value of each touched variable (`None` = was unset).
        saved: Vec<(&'static str, Option<String>)>,
        // Declared last so the lock is released only after `Drop::drop` has
        // restored the environment.
        _lock: std::sync::MutexGuard<'static, ()>,
    }

    #[cfg(all(target_arch = "aarch64", target_os = "macos"))]
    impl EnvOverride {
        /// Takes the lock, remembers the current values of `vars`, then sets
        /// (`Some`) or removes (`None`) each of them.
        fn apply(vars: &[(&'static str, Option<&str>)]) -> Self {
            // A panic in another env test poisons the lock; the protected data
            // is `()`, so recovering the guard is always fine.
            let lock = ENV_LOCK
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            let saved = vars
                .iter()
                .map(|&(name, _)| (name, std::env::var(name).ok()))
                .collect();
            for &(name, value) in vars {
                // SAFETY: environment mutation in this module is serialized by
                // `ENV_LOCK`, which is held for the lifetime of this guard.
                unsafe {
                    match value {
                        Some(v) => std::env::set_var(name, v),
                        None => std::env::remove_var(name),
                    }
                }
            }
            Self { saved, _lock: lock }
        }
    }

    #[cfg(all(target_arch = "aarch64", target_os = "macos"))]
    impl Drop for EnvOverride {
        fn drop(&mut self) {
            for (name, previous) in self.saved.drain(..) {
                // SAFETY: `ENV_LOCK` is still held here; `_lock` is dropped
                // only after this method returns.
                unsafe {
                    match previous {
                        Some(v) => std::env::set_var(name, v),
                        None => std::env::remove_var(name),
                    }
                }
            }
        }
    }

    #[tokio::test]
    async fn mock_embedder_round_trip() {
        let e = MockEmbedder::new(EMBED_DIM);
        assert_eq!(e.dimension(), EMBED_DIM);
        let v = embed_one(&e, "hello").await.unwrap();
        assert_eq!(v.len(), EMBED_DIM);
        let batch = e
            .embed_batch(&["a".to_string(), "b".to_string()])
            .await
            .unwrap();
        assert_eq!(batch.len(), 2);
        assert_ne!(batch[0], batch[1]);
    }

    #[tokio::test]
    async fn mock_embedder_empty_input_returns_empty() {
        let e = MockEmbedder::new(EMBED_DIM);
        let v = e.embed_batch(&[]).await.unwrap();
        assert!(v.is_empty());
    }

    // Ignored by default: constructs the real model.
    #[tokio::test]
    #[ignore]
    async fn fastembed_returns_correct_dim() {
        let e = FastEmbedder::new().await.unwrap();
        assert_eq!(e.dimension(), 384);
        let v = embed_one(&e, "fn authenticate(user: &str) -> bool")
            .await
            .unwrap();
        assert_eq!(v.len(), 384);
        assert!(v.iter().any(|x| *x != 0.0));
    }

    // Ignored by default: constructs the real model.
    #[tokio::test]
    #[ignore]
    async fn fastembed_cache_hit_is_idempotent() {
        let e = FastEmbedder::new().await.unwrap();
        let v1 = embed_one(&e, "cached").await.unwrap();
        let v2 = embed_one(&e, "cached").await.unwrap();
        assert_eq!(v1, v2);
    }

    #[cfg(all(target_arch = "aarch64", target_os = "macos"))]
    #[test]
    fn trusty_device_cpu_disables_coreml_on_apple_silicon() {
        let _env = EnvOverride::apply(&[("TRUSTY_DEVICE", Some("cpu"))]);
        let (_opts, provider) = FastEmbedder::init_options(EmbeddingModel::AllMiniLML6V2Q);
        assert_eq!(
            provider,
            ExecutionProvider::Cpu,
            "TRUSTY_DEVICE=cpu must suppress CoreML EP on Apple Silicon"
        );
    }

    #[cfg(all(target_arch = "aarch64", target_os = "macos"))]
    #[test]
    fn default_apple_silicon_uses_coreml_ane() {
        let _env = EnvOverride::apply(&[
            ("TRUSTY_DEVICE", None),
            ("TRUSTY_COREML_COMPUTE_UNITS", None),
        ]);
        let (_opts, provider) = FastEmbedder::init_options(EmbeddingModel::AllMiniLML6V2Q);
        assert_eq!(
            provider,
            ExecutionProvider::CoreMLAne,
            "default behaviour on Apple Silicon must register CoreML(ANE) — the OOM-safe replacement for CoreML(All)"
        );
    }

    #[cfg(all(target_arch = "aarch64", target_os = "macos"))]
    #[test]
    fn coreml_compute_units_all_opt_in() {
        let _env = EnvOverride::apply(&[
            ("TRUSTY_DEVICE", None),
            ("TRUSTY_COREML_COMPUTE_UNITS", Some("all")),
        ]);
        let (_opts, provider) = FastEmbedder::init_options(EmbeddingModel::AllMiniLML6V2Q);
        assert_eq!(
            provider,
            ExecutionProvider::CoreML,
            "TRUSTY_COREML_COMPUTE_UNITS=all must select the CoreML(All) tag"
        );
    }
}