use std::sync::Arc;
use async_trait::async_trait;
use lunaris_core::circuit_breaker::CircuitBreaker;
use lunaris_core::{Embedder, LunarisError, StorageError};
/// Identifier for the embedding provider wrapped by a `FallbackEmbedder`.
///
/// Attached as the `provider` field on the `tracing` events the wrapper emits.
/// Comparable and hashable so it can key maps/sets, and `Display`-able so it can
/// be formatted directly.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct EmbedderProviderId(String);

impl EmbedderProviderId {
    /// Creates an identifier from anything convertible into an owned `String`.
    #[must_use]
    pub fn new(s: impl Into<String>) -> Self {
        Self(s.into())
    }

    /// Borrows the identifier as a string slice.
    #[must_use]
    pub fn as_str(&self) -> &str {
        &self.0
    }
}

impl std::fmt::Display for EmbedderProviderId {
    /// Writes the raw identifier text, with no decoration.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
pub struct FallbackEmbedder<P, F>
where
P: Embedder,
F: Embedder,
{
primary: P,
fallback: F,
breaker: Arc<CircuitBreaker>,
provider_id: EmbedderProviderId,
}
impl<P, F> FallbackEmbedder<P, F>
where
P: Embedder,
F: Embedder,
{
pub fn new(primary: P, fallback: F, provider_id: EmbedderProviderId) -> Self {
debug_assert_eq!(
primary.dim(),
fallback.dim(),
"FallbackEmbedder: primary.dim() {} != fallback.dim() {}",
primary.dim(),
fallback.dim()
);
Self { primary, fallback, breaker: Arc::new(CircuitBreaker::new()), provider_id }
}
#[must_use]
pub fn with_breaker(mut self, breaker: Arc<CircuitBreaker>) -> Self {
self.breaker = breaker;
self
}
#[must_use]
pub fn breaker(&self) -> &Arc<CircuitBreaker> {
&self.breaker
}
#[must_use]
pub fn provider_id(&self) -> &EmbedderProviderId {
&self.provider_id
}
}
#[async_trait]
impl<P, F> Embedder for FallbackEmbedder<P, F>
where
    P: Embedder,
    F: Embedder,
{
    /// Reports the primary embedder's dimension; the fallback is expected to match it.
    fn dim(&self) -> usize {
        self.primary.dim()
    }

    /// Embeds `inputs`, preferring the primary and falling back when appropriate.
    ///
    /// Routing:
    /// * breaker rejects the request → the fallback is called and the primary is skipped;
    /// * primary succeeds → breaker success is recorded and the primary's output returned;
    /// * primary fails with an error that [`is_transient`] accepts → breaker failure is
    ///   recorded and the fallback's result is returned;
    /// * primary fails with any other error → that error is returned and the fallback
    ///   is not called.
    ///
    /// # Errors
    ///
    /// Returns the primary's non-transient error, or whatever the fallback returns.
    async fn embed_batch(&self, inputs: &[&str]) -> Result<Vec<Vec<f32>>, LunarisError> {
        let provider = self.provider_id.as_str();

        if self.breaker.allow_request() {
            // Return straight away on success; otherwise keep the error to classify it.
            let failure = match self.primary.embed_batch(inputs).await {
                Ok(embeddings) => {
                    self.breaker.on_success();
                    tracing::trace!(provider = %provider, "fallback_embedder.primary.success");
                    return Ok(embeddings);
                }
                Err(failure) => failure,
            };

            if !is_transient(&failure) {
                // Surface request-level problems instead of masking them with the fallback.
                // NOTE(review): the breaker is neither told of success nor failure on this
                // path — confirm that is intended for CircuitBreaker's half-open probes.
                tracing::warn!(
                    provider = %provider,
                    error = %failure,
                    "fallback_embedder.primary.terminal_failure — not retrying"
                );
                return Err(failure);
            }

            self.breaker.on_failure();
            tracing::warn!(
                provider = %provider,
                error = %failure,
                "fallback_embedder.primary.transient_failure — routing to fallback"
            );
            // `failure` goes out of scope here, before the fallback is awaited.
        } else {
            tracing::debug!(provider = %provider, "fallback_embedder.primary.tripped — breaker open");
        }

        self.fallback.embed_batch(inputs).await
    }
}
#[must_use]
pub fn is_transient(err: &LunarisError) -> bool {
match err {
LunarisError::Storage(StorageError::Backend(_)) => true,
LunarisError::Storage(StorageError::NotSupported(_)) => false,
LunarisError::Storage(_) => false,
LunarisError::Extract(lunaris_core::ExtractError::Timeout) => true,
LunarisError::Extract(lunaris_core::ExtractError::Backend(_)) => true,
LunarisError::Extract(lunaris_core::ExtractError::GrammarReject(_)) => false,
LunarisError::Validate(_) => false,
LunarisError::Retrieve(_) => false,
LunarisError::Consolidate(_) => false,
_ => false,
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test double that replays a scripted queue of results and counts its calls.
    ///
    /// Once the script is exhausted, it returns one zero vector of length `dim`
    /// per input.
    struct ScriptedEmbedder {
        dim: usize,
        results: Mutex<std::collections::VecDeque<Result<Vec<Vec<f32>>, LunarisError>>>,
        calls: AtomicUsize,
    }

    impl ScriptedEmbedder {
        fn new(dim: usize, results: Vec<Result<Vec<Vec<f32>>, LunarisError>>) -> Self {
            Self { dim, results: Mutex::new(results.into()), calls: AtomicUsize::new(0) }
        }

        /// Number of times `embed_batch` has been invoked.
        fn calls(&self) -> usize {
            self.calls.load(Ordering::Relaxed)
        }
    }

    #[async_trait]
    impl Embedder for ScriptedEmbedder {
        fn dim(&self) -> usize {
            self.dim
        }

        async fn embed_batch(&self, inputs: &[&str]) -> Result<Vec<Vec<f32>>, LunarisError> {
            self.calls.fetch_add(1, Ordering::Relaxed);
            // The guard is a temporary dropped at the end of this statement; no await follows.
            self.results
                .lock()
                .unwrap()
                .pop_front()
                .unwrap_or_else(|| Ok(vec![vec![0.0_f32; self.dim]; inputs.len()]))
        }
    }

    fn ok_vec(dim: usize) -> Result<Vec<Vec<f32>>, LunarisError> {
        Ok(vec![vec![0.1; dim]])
    }

    /// An error `is_transient` classifies as transient.
    fn transient_err() -> Result<Vec<Vec<f32>>, LunarisError> {
        Err(LunarisError::Storage(StorageError::Backend("upstream 503".into())))
    }

    /// An error `is_transient` classifies as terminal.
    fn terminal_err() -> Result<Vec<Vec<f32>>, LunarisError> {
        Err(LunarisError::Validate(lunaris_core::ValidateError::Temporal))
    }

    #[tokio::test]
    async fn primary_success_skips_fallback() {
        let primary = ScriptedEmbedder::new(768, vec![ok_vec(768)]);
        let fallback = ScriptedEmbedder::new(768, vec![ok_vec(768)]);
        let f = FallbackEmbedder::new(primary, fallback, EmbedderProviderId::new("test"));
        let _ = f.embed_batch(&["hello"]).await.unwrap();
        assert_eq!(f.primary.calls(), 1);
        assert_eq!(f.fallback.calls(), 0);
    }

    #[tokio::test]
    async fn transient_failure_routes_to_fallback() {
        let primary = ScriptedEmbedder::new(768, vec![transient_err()]);
        let fallback = ScriptedEmbedder::new(768, vec![ok_vec(768)]);
        let f = FallbackEmbedder::new(primary, fallback, EmbedderProviderId::new("test"));
        let _ = f.embed_batch(&["hello"]).await.unwrap();
        assert_eq!(f.primary.calls(), 1);
        assert_eq!(f.fallback.calls(), 1);
    }

    #[tokio::test]
    async fn terminal_failure_propagates() {
        let primary = ScriptedEmbedder::new(768, vec![terminal_err()]);
        let fallback = ScriptedEmbedder::new(768, vec![ok_vec(768)]);
        let f = FallbackEmbedder::new(primary, fallback, EmbedderProviderId::new("test"));
        let err = f.embed_batch(&["hello"]).await.unwrap_err();
        assert!(matches!(err, LunarisError::Validate(_)));
        assert_eq!(f.primary.calls(), 1);
        assert_eq!(f.fallback.calls(), 0, "terminal failure must not mask via fallback");
    }

    /// Terminal errors are never reported to the breaker, so even more of them than
    /// the transient trip threshold exercised below must keep routing to primary.
    #[tokio::test]
    async fn terminal_failures_do_not_trip_breaker() {
        let primary_calls: Vec<_> = (0..6).map(|_| terminal_err()).collect();
        let primary = ScriptedEmbedder::new(768, primary_calls);
        let fallback = ScriptedEmbedder::new(768, vec![]);
        let f = FallbackEmbedder::new(primary, fallback, EmbedderProviderId::new("test"));
        for _ in 0..6 {
            let err = f.embed_batch(&["x"]).await.unwrap_err();
            assert!(matches!(err, LunarisError::Validate(_)));
        }
        assert_eq!(f.primary.calls(), 6, "terminal errors must not open the breaker");
        assert_eq!(f.fallback.calls(), 0);
    }

    /// Relies on the default `CircuitBreaker::new()` opening after 5 consecutive failures.
    #[tokio::test]
    async fn breaker_trips_after_threshold() {
        let primary_calls: Vec<_> =
            (0..6).map(|i| if i < 5 { transient_err() } else { ok_vec(768) }).collect();
        let fallback_calls: Vec<_> = (0..6).map(|_| ok_vec(768)).collect();
        let primary = ScriptedEmbedder::new(768, primary_calls);
        let fallback = ScriptedEmbedder::new(768, fallback_calls);
        let f = FallbackEmbedder::new(primary, fallback, EmbedderProviderId::new("test"));
        for _ in 0..5 {
            let _ = f.embed_batch(&["x"]).await.unwrap();
        }
        assert_eq!(f.primary.calls(), 5);
        assert_eq!(f.fallback.calls(), 5);
        let _ = f.embed_batch(&["x"]).await.unwrap();
        assert_eq!(f.primary.calls(), 5, "breaker Open: primary skipped");
        assert_eq!(f.fallback.calls(), 6);
    }

    #[test]
    fn dim_returns_primary() {
        let primary = ScriptedEmbedder::new(768, vec![]);
        let fallback = ScriptedEmbedder::new(768, vec![]);
        let f = FallbackEmbedder::new(primary, fallback, EmbedderProviderId::new("test"));
        assert_eq!(f.dim(), 768);
    }

    // The dim check in `FallbackEmbedder::new` is a `debug_assert_eq!`, which is compiled
    // out when debug assertions are disabled (e.g. `cargo test --release`). Without this
    // gate the test would fail there because nothing panics.
    #[cfg(debug_assertions)]
    #[test]
    #[should_panic(expected = "FallbackEmbedder: primary.dim()")]
    fn debug_assert_mismatched_dim_panics() {
        let primary = ScriptedEmbedder::new(768, vec![]);
        let fallback = ScriptedEmbedder::new(384, vec![]);
        let _ = FallbackEmbedder::new(primary, fallback, EmbedderProviderId::new("test"));
    }
}