// `candle_embedder` and its re-exports are compiled only when the
// `embedder-candle` feature is enabled.
#[cfg(feature = "embedder-candle")]
pub mod candle_embedder;
#[cfg(feature = "embedder-candle")]
pub use candle_embedder::{CandleEmbedder, CandleEmbedderError};
pub mod rss;
// Private implementation modules; their public items are re-exported below.
mod fast_embedder;
mod types;
// `mock` is compiled for this crate's own tests, or when a dependant opts in
// through the `embedder-test-support` feature.
#[cfg(any(test, feature = "embedder-test-support"))]
mod mock;
pub use fast_embedder::FastEmbedder;
pub use types::{
CudaOptions, DEFAULT_CACHE_CAPACITY, DEFAULT_CUDA_GPU_MEM_LIMIT_BYTES,
DEFAULT_ORT_INTER_THREADS, DEFAULT_ORT_INTRA_THREADS, EMBED_DIM, Embedder, ExecutionProvider,
OrtThreadingOptions, embed_one, resolve_cuda_options, resolve_expected_provider,
resolve_fastembed_cache_dir, resolve_ort_threading_options,
};
// Same gate as the `mock` module declaration above.
#[cfg(any(test, feature = "embedder-test-support"))]
pub use mock::MockEmbedder;
#[cfg(test)]
mod tests {
use super::*;
static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
/// Checks the lookup order of `resolve_fastembed_cache_dir`:
/// `FASTEMBED_CACHE_DIR` first, then `FASTEMBED_CACHE_PATH`, and finally
/// `<home>/.cache/fastembed` when neither variable is set.
///
/// Every environment change goes through `EnvVarGuard`, so the caller's
/// environment is put back even when one of the assertions panics.
#[test]
fn resolve_fastembed_cache_dir_prefers_env_vars() {
    // A panic in another env test poisons ENV_LOCK; the mutex guards `()`, so
    // there is no state to distrust — recover the guard instead of cascading.
    let _g = ENV_LOCK
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);

    {
        let _dir = EnvVarGuard::apply("FASTEMBED_CACHE_DIR", Some("/tmp/fast-dir-test"));
        let _path = EnvVarGuard::apply("FASTEMBED_CACHE_PATH", None);
        assert_eq!(
            resolve_fastembed_cache_dir(),
            std::path::PathBuf::from("/tmp/fast-dir-test"),
            "FASTEMBED_CACHE_DIR must win when set"
        );
    }

    {
        let _dir = EnvVarGuard::apply("FASTEMBED_CACHE_DIR", None);
        let _path = EnvVarGuard::apply("FASTEMBED_CACHE_PATH", Some("/tmp/fast-path-test"));
        assert_eq!(
            resolve_fastembed_cache_dir(),
            std::path::PathBuf::from("/tmp/fast-path-test"),
            "FASTEMBED_CACHE_PATH must be honoured when FASTEMBED_CACHE_DIR is unset"
        );
    }

    {
        let _dir = EnvVarGuard::apply("FASTEMBED_CACHE_DIR", None);
        let _path = EnvVarGuard::apply("FASTEMBED_CACHE_PATH", None);
        // The fallback is only checked when a home directory can be determined.
        if let Some(home) = dirs::home_dir() {
            assert_eq!(
                resolve_fastembed_cache_dir(),
                home.join(".cache").join("fastembed"),
                "must fall back to $HOME/.cache/fastembed when no env vars set"
            );
        }
    }
}
/// RAII guard that overrides (or clears) a single environment variable and
/// puts the previous state back when dropped, including during a panic unwind.
///
/// The previous value is captured as an `OsString`, so a pre-existing value
/// that is not valid UTF-8 is restored verbatim rather than being mistaken for
/// "unset" and removed.
#[must_use = "the previous value is restored as soon as the guard is dropped"]
struct EnvVarGuard {
    /// Name of the variable under control.
    key: &'static str,
    /// Value the variable had before `apply`; `None` when it was unset.
    prev: Option<std::ffi::OsString>,
}

impl EnvVarGuard {
    /// Sets `key` to `value` (`Some`) or removes it (`None`), remembering the
    /// prior state so that `Drop` can restore it.
    fn apply(key: &'static str, value: Option<&str>) -> Self {
        let prev = std::env::var_os(key);
        // SAFETY: mutating the process environment is only sound while no other
        // thread touches it; callers are expected to hold `ENV_LOCK` for the
        // guard's whole lifetime.
        unsafe {
            match value {
                Some(v) => std::env::set_var(key, v),
                None => std::env::remove_var(key),
            }
        }
        Self { key, prev }
    }
}

impl Drop for EnvVarGuard {
    /// Restores the variable to the state captured by `apply`.
    fn drop(&mut self) {
        // SAFETY: same contract as in `apply` — the caller still serialises
        // environment access while the guard is being dropped.
        unsafe {
            match &self.prev {
                Some(v) => std::env::set_var(self.key, v),
                None => std::env::remove_var(self.key),
            }
        }
    }
}
/// With neither GPU memory knob set, `resolve_cuda_options` must report
/// `DEFAULT_CUDA_GPU_MEM_LIMIT_BYTES`, and that constant must equal 12 GiB.
#[test]
fn cuda_options_default_limit() {
    // A panic in another env test poisons ENV_LOCK; the mutex guards `()`, so
    // recover the guard instead of failing with an unrelated PoisonError.
    let _g = ENV_LOCK
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    let _b = EnvVarGuard::apply("TRUSTY_GPU_MEM_LIMIT_BYTES", None);
    let _m = EnvVarGuard::apply("TRUSTY_GPU_MEM_LIMIT_MB", None);
    assert_eq!(
        resolve_cuda_options().gpu_mem_limit_bytes,
        DEFAULT_CUDA_GPU_MEM_LIMIT_BYTES,
        "default CUDA gpu_mem_limit must be 12 GiB when no env knob is set"
    );
    assert_eq!(DEFAULT_CUDA_GPU_MEM_LIMIT_BYTES, 12 * 1024 * 1024 * 1024);
}
/// `TRUSTY_GPU_MEM_LIMIT_BYTES` alone must be passed through unchanged as the
/// GPU memory limit.
#[test]
fn cuda_options_bytes_env() {
    // A panic in another env test poisons ENV_LOCK; the mutex guards `()`, so
    // recover the guard instead of failing with an unrelated PoisonError.
    let _g = ENV_LOCK
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    let _b = EnvVarGuard::apply("TRUSTY_GPU_MEM_LIMIT_BYTES", Some("8589934592"));
    let _m = EnvVarGuard::apply("TRUSTY_GPU_MEM_LIMIT_MB", None);
    assert_eq!(
        resolve_cuda_options().gpu_mem_limit_bytes,
        8_589_934_592,
        "explicit byte limit must be used verbatim"
    );
}
/// `TRUSTY_GPU_MEM_LIMIT_MB` alone must be converted from MiB to bytes.
#[test]
fn cuda_options_mb_env() {
    // A panic in another env test poisons ENV_LOCK; the mutex guards `()`, so
    // recover the guard instead of failing with an unrelated PoisonError.
    let _g = ENV_LOCK
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    let _b = EnvVarGuard::apply("TRUSTY_GPU_MEM_LIMIT_BYTES", None);
    let _m = EnvVarGuard::apply("TRUSTY_GPU_MEM_LIMIT_MB", Some("4096"));
    assert_eq!(
        resolve_cuda_options().gpu_mem_limit_bytes,
        4096usize * 1024 * 1024,
        "MB limit must be scaled to bytes"
    );
}
/// When both knobs are set, the byte-valued one must win over the MB one.
#[test]
fn cuda_options_bytes_takes_precedence() {
    // A panic in another env test poisons ENV_LOCK; the mutex guards `()`, so
    // recover the guard instead of failing with an unrelated PoisonError.
    let _g = ENV_LOCK
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    let _b = EnvVarGuard::apply("TRUSTY_GPU_MEM_LIMIT_BYTES", Some("1073741824"));
    let _m = EnvVarGuard::apply("TRUSTY_GPU_MEM_LIMIT_MB", Some("4096"));
    assert_eq!(
        resolve_cuda_options().gpu_mem_limit_bytes,
        1_073_741_824,
        "BYTES knob must take precedence over MB"
    );
}
/// Unusable knob values must be skipped: a malformed BYTES value falls through
/// to a valid MB value, and zero BYTES plus malformed MB yields the default.
#[test]
fn cuda_options_ignores_malformed() {
    // A panic in another env test poisons ENV_LOCK; the mutex guards `()`, so
    // recover the guard instead of failing with an unrelated PoisonError.
    let _g = ENV_LOCK
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    // Each scenario lives in its own scope so its guards are dropped (and the
    // environment restored) before the next scenario is set up.
    {
        let _b = EnvVarGuard::apply("TRUSTY_GPU_MEM_LIMIT_BYTES", Some("not-a-number"));
        let _m = EnvVarGuard::apply("TRUSTY_GPU_MEM_LIMIT_MB", Some("2048"));
        assert_eq!(
            resolve_cuda_options().gpu_mem_limit_bytes,
            2048usize * 1024 * 1024,
            "malformed BYTES must fall through to a valid MB knob"
        );
    }
    {
        let _b = EnvVarGuard::apply("TRUSTY_GPU_MEM_LIMIT_BYTES", Some("0"));
        let _m = EnvVarGuard::apply("TRUSTY_GPU_MEM_LIMIT_MB", Some("nope"));
        assert_eq!(
            resolve_cuda_options().gpu_mem_limit_bytes,
            DEFAULT_CUDA_GPU_MEM_LIMIT_BYTES,
            "zero BYTES + malformed MB must fall back to the safe default"
        );
    }
}
/// With no ORT threading variables set, `resolve_ort_threading_options` must
/// return the default intra/inter thread counts (both pinned to 1 here) and
/// leave spinning disabled.
#[test]
fn ort_threading_defaults() {
    // A panic in another env test poisons ENV_LOCK; the mutex guards `()`, so
    // recover the guard instead of failing with an unrelated PoisonError.
    let _g = ENV_LOCK
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    let _i = EnvVarGuard::apply("TRUSTY_ORT_INTRA_THREADS", None);
    let _e = EnvVarGuard::apply("TRUSTY_ORT_INTER_THREADS", None);
    let _s = EnvVarGuard::apply("TRUSTY_ORT_ALLOW_SPINNING", None);
    let opts = resolve_ort_threading_options();
    assert_eq!(
        opts.intra_threads, DEFAULT_ORT_INTRA_THREADS,
        "default intra-op threads must be 1 to avoid the #1542 barrier deadlock"
    );
    assert_eq!(opts.inter_threads, DEFAULT_ORT_INTER_THREADS);
    assert!(
        !opts.allow_spinning,
        "spinning must default to off — it is the busy-wait half of the deadlock"
    );
    // Pin the constants themselves so a silent change to the defaults is caught.
    assert_eq!(DEFAULT_ORT_INTRA_THREADS, 1);
    assert_eq!(DEFAULT_ORT_INTER_THREADS, 1);
}
/// Valid `TRUSTY_ORT_INTRA_THREADS` / `TRUSTY_ORT_INTER_THREADS` values must be
/// reflected in the resolved options; spinning stays off when its knob is unset.
#[test]
fn ort_threading_reads_env() {
    // A panic in another env test poisons ENV_LOCK; the mutex guards `()`, so
    // recover the guard instead of failing with an unrelated PoisonError.
    let _g = ENV_LOCK
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    let _i = EnvVarGuard::apply("TRUSTY_ORT_INTRA_THREADS", Some("4"));
    let _e = EnvVarGuard::apply("TRUSTY_ORT_INTER_THREADS", Some("2"));
    let _s = EnvVarGuard::apply("TRUSTY_ORT_ALLOW_SPINNING", None);
    let opts = resolve_ort_threading_options();
    assert_eq!(
        opts.intra_threads, 4,
        "explicit intra-op count must be used"
    );
    assert_eq!(
        opts.inter_threads, 2,
        "explicit inter-op count must be used"
    );
    assert!(!opts.allow_spinning);
}
/// A non-numeric intra-op count and a zero inter-op count must both be
/// rejected in favour of the defaults.
#[test]
fn ort_threading_ignores_malformed() {
    // A panic in another env test poisons ENV_LOCK; the mutex guards `()`, so
    // recover the guard instead of failing with an unrelated PoisonError.
    let _g = ENV_LOCK
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    let _i = EnvVarGuard::apply("TRUSTY_ORT_INTRA_THREADS", Some("not-a-number"));
    let _e = EnvVarGuard::apply("TRUSTY_ORT_INTER_THREADS", Some("0"));
    let _s = EnvVarGuard::apply("TRUSTY_ORT_ALLOW_SPINNING", None);
    let opts = resolve_ort_threading_options();
    assert_eq!(
        opts.intra_threads, DEFAULT_ORT_INTRA_THREADS,
        "malformed intra-op count must fall back to the safe default"
    );
    assert_eq!(
        opts.inter_threads, DEFAULT_ORT_INTER_THREADS,
        "zero inter-op count must fall back to the safe default"
    );
}
/// `TRUSTY_ORT_ALLOW_SPINNING=TRUE` must enable spinning, while an
/// unrecognised value such as `maybe` must leave it disabled.
#[test]
fn ort_threading_spinning_truthy() {
    // A panic in another env test poisons ENV_LOCK; the mutex guards `()`, so
    // recover the guard instead of failing with an unrelated PoisonError.
    let _g = ENV_LOCK
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    let _i = EnvVarGuard::apply("TRUSTY_ORT_INTRA_THREADS", None);
    let _e = EnvVarGuard::apply("TRUSTY_ORT_INTER_THREADS", None);
    // Separate scopes: the first spinning override is undone before the second
    // one is installed.
    {
        let _s = EnvVarGuard::apply("TRUSTY_ORT_ALLOW_SPINNING", Some("TRUE"));
        assert!(
            resolve_ort_threading_options().allow_spinning,
            "a truthy value (case-insensitive) must enable spinning"
        );
    }
    {
        let _s = EnvVarGuard::apply("TRUSTY_ORT_ALLOW_SPINNING", Some("maybe"));
        assert!(
            !resolve_ort_threading_options().allow_spinning,
            "a non-truthy value must leave spinning disabled"
        );
    }
}
/// `TRUSTY_DEVICE=CPU` (upper case) must make `resolve_expected_provider`
/// report the CPU provider.
#[test]
fn resolve_expected_provider_forces_cpu() {
    // A panic in another env test poisons ENV_LOCK; the mutex guards `()`, so
    // recover the guard instead of failing with an unrelated PoisonError.
    let _g = ENV_LOCK
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    let _d = EnvVarGuard::apply("TRUSTY_DEVICE", Some("CPU"));
    assert_eq!(resolve_expected_provider(), ExecutionProvider::Cpu);
}
/// With no device or CoreML overrides, the predicted provider must match the
/// build: CUDA when the `embedder-cuda` feature is on, CoreML(ANE) on
/// aarch64 macOS without that feature, and CPU everywhere else.
#[test]
fn resolve_expected_provider_default_matches_platform() {
    // A panic in another env test poisons ENV_LOCK; the mutex guards `()`, so
    // recover the guard instead of failing with an unrelated PoisonError.
    let _g = ENV_LOCK
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    let _d = EnvVarGuard::apply("TRUSTY_DEVICE", None);
    let _u = EnvVarGuard::apply("TRUSTY_COREML_COMPUTE_UNITS", None);
    let got = resolve_expected_provider();

    // Exactly one of the three mutually exclusive cfg predicates below is
    // active in any given build, so `expected` is always bound once.
    #[cfg(feature = "embedder-cuda")]
    let expected = ExecutionProvider::Cuda;
    #[cfg(all(
        not(feature = "embedder-cuda"),
        target_arch = "aarch64",
        target_os = "macos"
    ))]
    let expected = ExecutionProvider::CoreMLAne;
    #[cfg(all(
        not(feature = "embedder-cuda"),
        not(all(target_arch = "aarch64", target_os = "macos"))
    ))]
    let expected = ExecutionProvider::Cpu;

    assert_eq!(
        got, expected,
        "predicted provider must match init_options for this build/platform"
    );
}
/// On aarch64 macOS builds without `embedder-cuda`:
/// `TRUSTY_COREML_COMPUTE_UNITS=all` must predict `CoreML`, and leaving the
/// variable unset must predict `CoreMLAne`.
#[cfg(all(
    target_arch = "aarch64",
    target_os = "macos",
    not(feature = "embedder-cuda")
))]
#[test]
fn resolve_expected_provider_coreml_units() {
    // A panic in another env test poisons ENV_LOCK; the mutex guards `()`, so
    // recover the guard instead of failing with an unrelated PoisonError.
    let _g = ENV_LOCK
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    let _d = EnvVarGuard::apply("TRUSTY_DEVICE", None);
    {
        let _u = EnvVarGuard::apply("TRUSTY_COREML_COMPUTE_UNITS", Some("all"));
        assert_eq!(resolve_expected_provider(), ExecutionProvider::CoreML);
    }
    {
        let _u = EnvVarGuard::apply("TRUSTY_COREML_COMPUTE_UNITS", None);
        assert_eq!(resolve_expected_provider(), ExecutionProvider::CoreMLAne);
    }
}
/// Exercises `MockEmbedder` end to end: the reported dimension, a single
/// embedding obtained through `embed_one`, and a two-item batch whose vectors
/// must differ from each other.
#[tokio::test]
async fn mock_embedder_round_trip() {
    let embedder = MockEmbedder::new(EMBED_DIM);
    assert_eq!(embedder.dimension(), EMBED_DIM);

    let single = embed_one(&embedder, "hello").await.unwrap();
    assert_eq!(single.len(), EMBED_DIM);

    let inputs = ["a".to_string(), "b".to_string()];
    let vectors = embedder.embed_batch(&inputs).await.unwrap();
    assert_eq!(vectors.len(), 2);
    // Distinct inputs must not collapse onto the same vector.
    assert_ne!(vectors[0], vectors[1]);
}
/// Embedding an empty batch with `MockEmbedder` must succeed and yield no
/// vectors.
#[tokio::test]
async fn mock_embedder_empty_input_returns_empty() {
    let embedder = MockEmbedder::new(EMBED_DIM);
    let vectors = embedder.embed_batch(&[]).await.unwrap();
    assert!(vectors.is_empty());
}
/// Ignored by default. Builds a real `FastEmbedder` and checks that it reports
/// and produces 384-dimensional vectors that are not entirely zero.
#[tokio::test]
#[ignore]
async fn fastembed_returns_correct_dim() {
    let embedder = FastEmbedder::new().await.unwrap();
    assert_eq!(embedder.dimension(), 384);

    let snippet = "fn authenticate(user: &str) -> bool";
    let vector = embed_one(&embedder, snippet).await.unwrap();
    assert_eq!(vector.len(), 384);
    assert!(vector.iter().any(|component| *component != 0.0));
}
/// Ignored by default. Embedding the same text twice with a real
/// `FastEmbedder` must return identical vectors.
#[tokio::test]
#[ignore]
async fn fastembed_cache_hit_is_idempotent() {
    let embedder = FastEmbedder::new().await.unwrap();
    let first = embed_one(&embedder, "cached").await.unwrap();
    let second = embed_one(&embedder, "cached").await.unwrap();
    assert_eq!(first, second);
}
/// On aarch64 macOS, `TRUSTY_DEVICE=cpu` must make
/// `FastEmbedder::init_options` report the CPU provider.
///
/// The override goes through `EnvVarGuard`, so `TRUSTY_DEVICE` is restored
/// even when the assertion fails.
#[cfg(all(target_arch = "aarch64", target_os = "macos"))]
#[test]
fn trusty_device_cpu_disables_coreml_on_apple_silicon() {
    use crate::embedder::fast_embedder::FastEmbedder as FE;
    // A panic in another env test poisons ENV_LOCK; the mutex guards `()`, so
    // recover the guard instead of failing with an unrelated PoisonError.
    let _guard = ENV_LOCK
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    let _device = EnvVarGuard::apply("TRUSTY_DEVICE", Some("cpu"));
    let (_opts, provider) = FE::init_options(fastembed::EmbeddingModel::AllMiniLML6V2Q);
    assert_eq!(
        provider,
        ExecutionProvider::Cpu,
        "TRUSTY_DEVICE=cpu must suppress CoreML EP on Apple Silicon"
    );
}
/// On aarch64 macOS with neither `TRUSTY_DEVICE` nor
/// `TRUSTY_COREML_COMPUTE_UNITS` set, `FastEmbedder::init_options` must report
/// the `CoreMLAne` provider.
///
/// Both variables are cleared through `EnvVarGuard`, so they are restored even
/// when the assertion fails.
#[cfg(all(target_arch = "aarch64", target_os = "macos"))]
#[test]
fn default_apple_silicon_uses_coreml_ane() {
    use crate::embedder::fast_embedder::FastEmbedder as FE;
    // A panic in another env test poisons ENV_LOCK; the mutex guards `()`, so
    // recover the guard instead of failing with an unrelated PoisonError.
    let _guard = ENV_LOCK
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    let _device = EnvVarGuard::apply("TRUSTY_DEVICE", None);
    let _units = EnvVarGuard::apply("TRUSTY_COREML_COMPUTE_UNITS", None);
    let (_opts, provider) = FE::init_options(fastembed::EmbeddingModel::AllMiniLML6V2Q);
    assert_eq!(
        provider,
        ExecutionProvider::CoreMLAne,
        "default behaviour on Apple Silicon must register CoreML(ANE) — the OOM-safe replacement for CoreML(All)"
    );
}
/// On aarch64 macOS, `TRUSTY_COREML_COMPUTE_UNITS=all` (with `TRUSTY_DEVICE`
/// unset) must make `FastEmbedder::init_options` report the `CoreML` provider.
///
/// Both variables are managed through `EnvVarGuard`, so they are restored even
/// when the assertion fails.
#[cfg(all(target_arch = "aarch64", target_os = "macos"))]
#[test]
fn coreml_compute_units_all_opt_in() {
    use crate::embedder::fast_embedder::FastEmbedder as FE;
    // A panic in another env test poisons ENV_LOCK; the mutex guards `()`, so
    // recover the guard instead of failing with an unrelated PoisonError.
    let _guard = ENV_LOCK
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    let _device = EnvVarGuard::apply("TRUSTY_DEVICE", None);
    let _units = EnvVarGuard::apply("TRUSTY_COREML_COMPUTE_UNITS", Some("all"));
    let (_opts, provider) = FE::init_options(fastembed::EmbeddingModel::AllMiniLML6V2Q);
    assert_eq!(
        provider,
        ExecutionProvider::CoreML,
        "TRUSTY_COREML_COMPUTE_UNITS=all must select the CoreML(All) tag"
    );
}
/// `types::is_zero_vector` must flag an all-zero vector, accept a vector with
/// one non-zero component, and accept what `MockEmbedder::hash_to_vec`
/// produces for a non-empty string.
#[test]
fn zero_vector_guard_rejects_all_zero_batch() {
    let all_zeros = vec![0.0_f32; EMBED_DIM];
    assert!(
        types::is_zero_vector(&all_zeros),
        "synthetic zero vector must be detected by is_zero_vector"
    );

    // Same shape, but with the first component switched on.
    let mut one_hot = vec![0.0_f32; EMBED_DIM];
    one_hot[0] = 1.0;
    assert!(
        !types::is_zero_vector(&one_hot),
        "non-zero vector must NOT be detected by is_zero_vector"
    );

    let hashed = MockEmbedder::new(EMBED_DIM).hash_to_vec("some text");
    assert!(
        !types::is_zero_vector(&hashed),
        "MockEmbedder must produce non-zero vectors for non-empty input"
    );
}
}