#[cfg(feature = "alloc")]
use alloc::{string::String, vec::Vec};
#[cfg(all(feature = "alloc", feature = "std"))]
use alloc::string::ToString;
use hashbrown::HashMap;
#[cfg(feature = "std")]
use super::error::{OxiRouterError, Result};
use super::query_log::QueryLog;
use super::source::DataSource;
use crate::context::ContextProvider;
#[cfg(feature = "ml")]
use crate::ml::Model;
#[cfg(feature = "rl")]
use crate::rl::Policy;
#[cfg(feature = "cache")]
use crate::cache::CacheManager;
mod cache;
mod federation;
mod operations;
mod routing;
/// One named term of a source's routing score, as listed in [`RoutingExplanation::components`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ScoreComponent {
    /// Name identifying the scoring signal this term comes from.
    pub name: String,
    /// Weight associated with this signal (presumably the scorer's configured weight — confirm
    /// in the routing module).
    pub weight: f32,
    /// Signal value before weighting (NOTE(review): inferred from the name — confirm).
    pub raw_value: f32,
    /// Amount this term adds to the total score (presumably `weight * raw_value`; verify against
    /// the code that builds explanations).
    pub contribution: f32,
}
/// Per-source breakdown of a routing score into its [`ScoreComponent`] terms.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct RoutingExplanation {
    /// Id of the source being explained (presumably the same id used as the router's source key).
    pub source_id: String,
    /// Final score for the source.
    pub total_score: f32,
    /// The individual terms behind `total_score`.
    pub components: Vec<ScoreComponent>,
}
/// Serde default for [`RouterConfig::cache_enabled`]: caching starts switched on.
#[cfg(feature = "cache")]
const fn default_cache_enabled() -> bool {
    true
}

/// Serde default for [`RouterConfig::cache_ttl_ms`]: five minutes, in milliseconds.
#[cfg(feature = "cache")]
const fn default_cache_ttl_ms() -> u64 {
    5 * 60 * 1_000
}

/// Serde default for [`RouterConfig::cache_max_entries`].
#[cfg(feature = "cache")]
const fn default_cache_max_entries() -> usize {
    1_000
}

/// Serde default for [`RouterConfig::max_response_bytes`]: 64 MiB.
const fn default_max_response_bytes() -> u64 {
    64 << 20
}
/// Combines whole seconds and the sub-second millisecond part into total milliseconds,
/// saturating at `u64::MAX` instead of truncating like an `as u64` cast from `u128` would.
// Only the `std` clock in `default_now_ms` calls this; `no_std` builds never do.
#[cfg_attr(not(feature = "std"), allow(dead_code))]
fn saturating_millis(secs: u64, subsec_millis: u32) -> u64 {
    secs.saturating_mul(1_000)
        .saturating_add(u64::from(subsec_millis))
}

/// Default clock for `CircuitBreakerConfig::now_ms`.
///
/// With the `std` feature this returns a function that reads `SystemTime::now()` as
/// milliseconds since the Unix epoch (0 if the system clock reports a time before the epoch).
/// Without `std` there is no clock available and this returns `None`.
// Under `std` the result is always `Some`, which trips `clippy::unnecessary_wraps`; the
// `Option` is still required for the `no_std` branch and for the field's type.
#[allow(clippy::unnecessary_wraps)]
fn default_now_ms() -> Option<fn() -> u64> {
    #[cfg(feature = "std")]
    {
        Some(|| {
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map(|d| saturating_millis(d.as_secs(), d.subsec_millis()))
                .unwrap_or(0)
        })
    }
    #[cfg(not(feature = "std"))]
    {
        None
    }
}
/// Circuit-breaker settings, stored in [`RouterConfig::circuit_breaker`].
///
/// `now_ms` is `#[serde(skip)]`. It is never serialized, and on deserialization it is filled
/// from `default_now_ms` (a system-clock reader with the `std` feature, `None` without it).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CircuitBreakerConfig {
    /// Failure count at which the breaker trips (presumably consecutive failures of one source —
    /// confirm where the breaker is evaluated). `Default` uses 5.
    pub failure_threshold: u32,
    /// Cool-down period in milliseconds (per the `_ms` suffix) before a tripped source is
    /// reconsidered. NOTE(review): retry semantics live outside this file. `Default` uses 30 000.
    pub cooldown_ms: u64,
    /// Clock returning "now" in milliseconds; the default reads milliseconds since the Unix
    /// epoch. As a plain `fn` pointer it can presumably be replaced by a fixed clock in tests.
    #[serde(skip, default = "default_now_ms")]
    pub now_ms: Option<fn() -> u64>,
}
impl Default for CircuitBreakerConfig {
fn default() -> Self {
Self {
failure_threshold: 5,
cooldown_ms: 30_000,
#[cfg(feature = "std")]
now_ms: Some(|| {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_millis() as u64)
.unwrap_or(0)
}),
#[cfg(not(feature = "std"))]
now_ms: None,
}
}
}
/// Configuration for [`Router`].
///
/// Can be loaded from JSON with `RouterConfig::from_config_file` when the `std` feature is on.
/// Only `max_response_bytes` and the `cache_*` fields have serde defaults. Every other field,
/// including `circuit_breaker`, must be present when deserializing.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct RouterConfig {
    /// Cap on the number of sources (presumably per routing result — confirm). Default: 5.
    pub max_sources: usize,
    /// Confidence threshold (presumably sources scoring below it are dropped — confirm).
    /// Default: 0.1.
    pub min_confidence: f32,
    /// Whether ML scoring is requested (presumably only effective with the `ml` feature).
    /// Default: `true`.
    pub use_ml: bool,
    /// Whether the context provider is consulted. Default: `true`.
    pub use_context: bool,
    /// Timeout in microseconds (per the `_us` suffix); what it bounds is decided by the
    /// routing code — confirm. Default: 100 000 (100 ms).
    pub timeout_us: u64,
    /// Scoring weight for the history signal (inferred from the name). Default: 0.3.
    pub history_weight: f32,
    /// Scoring weight for the vocabulary signal (inferred from the name). Default: 0.4.
    pub vocab_weight: f32,
    /// Scoring weight for the geographic signal (inferred from the name). Default: 0.3.
    pub geo_weight: f32,
    /// Circuit-breaker settings.
    pub circuit_breaker: CircuitBreakerConfig,
    /// Response size cap in bytes. Defaults to 64 MiB, also when omitted from JSON.
    #[serde(default = "default_max_response_bytes")]
    pub max_response_bytes: u64,
    /// Whether result caching is on. Default: `true`, also when omitted from JSON.
    #[cfg(feature = "cache")]
    #[serde(default = "default_cache_enabled")]
    pub cache_enabled: bool,
    /// Cached-result TTL in milliseconds. Default: 300 000 (5 min), also when omitted from JSON.
    #[cfg(feature = "cache")]
    #[serde(default = "default_cache_ttl_ms")]
    pub cache_ttl_ms: u64,
    /// Maximum number of cached entries. Default: 1000, also when omitted from JSON.
    #[cfg(feature = "cache")]
    #[serde(default = "default_cache_max_entries")]
    pub cache_max_entries: usize,
}
impl Default for RouterConfig {
    /// Default routing configuration.
    ///
    /// Fields that also carry a serde `default = ...` attribute use the very same helper
    /// functions here, so a config built in code always matches a JSON config that omits
    /// those fields. The literals used to be repeated here and could silently drift.
    fn default() -> Self {
        Self {
            max_sources: 5,
            min_confidence: 0.1,
            use_ml: true,
            use_context: true,
            timeout_us: 100_000,
            // The three scoring weights sum to 1.0.
            history_weight: 0.3,
            vocab_weight: 0.4,
            geo_weight: 0.3,
            circuit_breaker: CircuitBreakerConfig::default(),
            max_response_bytes: default_max_response_bytes(),
            #[cfg(feature = "cache")]
            cache_enabled: default_cache_enabled(),
            #[cfg(feature = "cache")]
            cache_ttl_ms: default_cache_ttl_ms(),
            #[cfg(feature = "cache")]
            cache_max_entries: default_cache_max_entries(),
        }
    }
}
#[cfg(feature = "std")]
impl RouterConfig {
    /// Reads the file at `path` and parses it as a JSON-encoded [`RouterConfig`].
    ///
    /// Fields without a serde default must be present in the file. Those are all fields
    /// except `max_response_bytes` and the `cache_*` fields.
    ///
    /// # Errors
    ///
    /// Returns `OxiRouterError::InvalidSource` with the underlying error's text when the file
    /// cannot be read, or when its bytes are not a valid `RouterConfig` document.
    pub fn from_config_file<P: AsRef<std::path::Path>>(path: P) -> Result<RouterConfig> {
        std::fs::read(path)
            .map_err(|io_err| OxiRouterError::InvalidSource(io_err.to_string()))
            .and_then(|raw| {
                serde_json::from_slice::<RouterConfig>(&raw)
                    .map_err(|json_err| OxiRouterError::InvalidSource(json_err.to_string()))
            })
    }
}
/// Routes queries across a set of registered data sources.
///
/// Generic over the context provider `C` (defaults to `DefaultContextProvider`). The `model`,
/// `model_bytes_cache`, `online_training_enabled`, `policy`, `cache` and `planner` fields only
/// exist when their cargo feature (`ml`, `rl`, `cache`, `http`) is enabled.
pub struct Router<C: ContextProvider = crate::context::DefaultContextProvider> {
    /// Registered sources, keyed by each source's `id` (see `add_source`).
    pub(super) sources: HashMap<String, DataSource>,
    /// Active configuration.
    pub(super) config: RouterConfig,
    /// Supplier of routing context.
    pub(super) context_provider: C,
    /// Optional ML model; the constructors in this file leave it `None`.
    #[cfg(feature = "ml")]
    pub(super) model: Option<Box<dyn Model>>,
    /// Byte buffer associated with the model (presumably its serialized form — confirm);
    /// constructors leave it `None`.
    #[cfg(feature = "ml")]
    pub(super) model_bytes_cache: Option<Vec<u8>>,
    /// Online-training flag; constructors set it to `true`.
    #[cfg(feature = "ml")]
    pub(super) online_training_enabled: bool,
    /// Optional RL policy; constructors leave it `None`.
    #[cfg(feature = "rl")]
    pub(super) policy: Option<Policy>,
    /// Query log; constructors start with `QueryLog::new()`.
    pub(super) query_log: QueryLog,
    /// Cache manager used by the cache-related router methods.
    #[cfg(feature = "cache")]
    pub(super) cache: CacheManager,
    /// Federated planner; constructors install `DefaultPlanner::default()`.
    #[cfg(feature = "http")]
    pub(super) planner: Box<dyn crate::federation::planner::FederatedPlanner>,
}
impl Router<crate::context::DefaultContextProvider> {
    /// Creates an empty router with [`RouterConfig::default`] and the default context provider.
    ///
    /// Delegates to [`Router::with_context_provider`]. Before, this repeated the whole field
    /// list, including every feature-gated field, and the two copies had to be kept in sync
    /// by hand.
    #[must_use]
    pub fn new() -> Self {
        Self::with_context_provider(crate::context::DefaultContextProvider)
    }

    /// Creates an empty router from a JSON config file, using the default context provider.
    ///
    /// # Errors
    ///
    /// Returns the error from [`RouterConfig::from_config_file`]: the file could not be read,
    /// or it is not a valid `RouterConfig` document.
    #[cfg(feature = "std")]
    pub fn with_config_file<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
        let cfg = RouterConfig::from_config_file(path)?;
        Ok(Self::with_config(cfg, crate::context::DefaultContextProvider))
    }
}

impl Default for Router<crate::context::DefaultContextProvider> {
    /// Same as [`Router::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl<C: ContextProvider> Router<C> {
    /// Creates an empty router around `context_provider`, using [`RouterConfig::default`].
    ///
    /// With the `cache` feature the cache is `CacheManager::new()`. It is *not* sized from the
    /// config the way [`Router::with_config`] sizes it.
    #[must_use]
    pub fn with_context_provider(context_provider: C) -> Self {
        let config = RouterConfig::default();
        Self {
            sources: HashMap::new(),
            config,
            context_provider,
            #[cfg(feature = "ml")]
            model: None,
            #[cfg(feature = "ml")]
            model_bytes_cache: None,
            #[cfg(feature = "ml")]
            online_training_enabled: true,
            #[cfg(feature = "rl")]
            policy: None,
            query_log: QueryLog::new(),
            #[cfg(feature = "cache")]
            cache: CacheManager::new(),
            #[cfg(feature = "http")]
            planner: Box::new(crate::federation::planner::DefaultPlanner::default()),
        }
    }

    /// Creates an empty router with an explicit `config` and `context_provider`.
    ///
    /// With the `cache` feature, the cache is built from `config.cache_max_entries` and
    /// `config.cache_ttl_ms`, together with the crate constants `DEFAULT_CONTEXT_TTL_MS` and
    /// `DEFAULT_SOURCE_TTL_MS`. `config.cache_enabled` is not read here.
    #[must_use]
    pub fn with_config(config: RouterConfig, context_provider: C) -> Self {
        #[cfg(feature = "cache")]
        let cache = {
            let max_entries = config.cache_max_entries;
            let query_ttl_ms = config.cache_ttl_ms;
            CacheManager::with_config(
                max_entries,
                query_ttl_ms,
                crate::cache::DEFAULT_CONTEXT_TTL_MS,
                crate::cache::DEFAULT_SOURCE_TTL_MS,
            )
        };
        let sources = HashMap::new();
        Self {
            sources,
            config,
            context_provider,
            #[cfg(feature = "ml")]
            model: None,
            #[cfg(feature = "ml")]
            model_bytes_cache: None,
            #[cfg(feature = "ml")]
            online_training_enabled: true,
            #[cfg(feature = "rl")]
            policy: None,
            query_log: QueryLog::new(),
            #[cfg(feature = "cache")]
            cache,
            #[cfg(feature = "http")]
            planner: Box::new(crate::federation::planner::DefaultPlanner::default()),
        }
    }

    /// Registers `source` under its `id`.
    ///
    /// A source already registered under the same id is replaced and dropped.
    pub fn add_source(&mut self, source: DataSource) {
        let key = source.id.clone();
        self.sources.insert(key, source);
    }

    /// Unregisters the source with the given `id` and returns it, or `None` if it was unknown.
    pub fn remove_source(&mut self, id: &str) -> Option<DataSource> {
        self.sources.remove(id)
    }

    /// Looks up a registered source by `id`.
    #[must_use]
    pub fn get_source(&self, id: &str) -> Option<&DataSource> {
        self.sources.get(id)
    }

    /// Looks up a registered source by `id` for in-place modification.
    #[must_use]
    pub fn get_source_mut(&mut self, id: &str) -> Option<&mut DataSource> {
        self.sources.get_mut(id)
    }

    /// Iterates over all registered sources in the map's internal (unspecified) order.
    pub fn sources(&self) -> impl Iterator<Item = &DataSource> {
        self.sources.values()
    }

    /// Number of registered sources.
    #[must_use]
    pub fn source_count(&self) -> usize {
        self.sources.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::error::OxiRouterError;
    use crate::core::query::Query;

    #[test]
    fn test_router_creation() {
        let router = Router::new();
        assert_eq!(router.source_count(), 0);
    }

    #[test]
    fn test_default_router_is_empty() {
        let router: Router = Router::default();
        assert_eq!(router.source_count(), 0);
        assert_eq!(router.sources().count(), 0);
    }

    #[test]
    fn test_add_source() {
        let mut router = Router::new();
        router.add_source(DataSource::new("test", "http://example.com/sparql"));
        assert_eq!(router.source_count(), 1);
        assert!(router.get_source("test").is_some());
    }

    #[test]
    fn test_add_source_replaces_same_id() {
        let mut router = Router::new();
        router.add_source(DataSource::new("dup", "http://a.example.org/sparql"));
        router.add_source(DataSource::new("dup", "http://b.example.org/sparql"));
        assert_eq!(router.source_count(), 1);
        assert_eq!(router.sources().count(), 1);
    }

    #[test]
    fn test_remove_source() {
        let mut router = Router::new();
        router.add_source(DataSource::new("test", "http://example.com/sparql"));
        let removed = router
            .remove_source("test")
            .expect("source was registered just above");
        assert_eq!(removed.id, "test");
        assert_eq!(router.source_count(), 0);
        assert!(router.get_source("test").is_none());
        assert!(router.remove_source("test").is_none());
    }

    #[test]
    fn test_with_config_keeps_config() {
        let config = RouterConfig {
            max_sources: 2,
            ..RouterConfig::default()
        };
        let router = Router::with_config(config, crate::context::DefaultContextProvider);
        assert_eq!(router.config.max_sources, 2);
        assert_eq!(router.source_count(), 0);
    }

    #[test]
    fn test_router_config_default_matches_serde_defaults() {
        let cfg = RouterConfig::default();
        assert_eq!(cfg.max_response_bytes, default_max_response_bytes());
        #[cfg(feature = "cache")]
        {
            assert_eq!(cfg.cache_enabled, default_cache_enabled());
            assert_eq!(cfg.cache_ttl_ms, default_cache_ttl_ms());
            assert_eq!(cfg.cache_max_entries, default_cache_max_entries());
        }
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_router_config_json_round_trip() {
        let cfg = RouterConfig::default();
        assert!(cfg.circuit_breaker.now_ms.is_some());
        let json = serde_json::to_string(&cfg).unwrap();
        let back: RouterConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(back.max_sources, cfg.max_sources);
        assert_eq!(back.timeout_us, cfg.timeout_us);
        assert_eq!(back.max_response_bytes, cfg.max_response_bytes);
        assert_eq!(
            back.circuit_breaker.failure_threshold,
            cfg.circuit_breaker.failure_threshold
        );
        assert_eq!(back.circuit_breaker.cooldown_ms, cfg.circuit_breaker.cooldown_ms);
        // `now_ms` is `#[serde(skip)]` and must be restored from `default_now_ms`.
        assert!(back.circuit_breaker.now_ms.is_some());
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_router_config_requires_non_defaulted_fields() {
        assert!(serde_json::from_str::<RouterConfig>("{}").is_err());
    }

    #[test]
    fn test_route_no_sources() {
        let router = Router::new();
        let query = Query::parse("SELECT ?s WHERE { ?s ?p ?o }").unwrap();
        let result = router.route(&query);
        assert!(matches!(result, Err(OxiRouterError::NoSources { .. })));
    }

    #[test]
    fn test_route_with_sources() {
        let mut router = Router::new();
        router.add_source(
            DataSource::new("dbpedia", "http://dbpedia.org/sparql")
                .with_vocabulary("http://schema.org/"),
        );
        router.add_source(
            DataSource::new("wikidata", "http://wikidata.org/sparql")
                .with_vocabulary("http://www.wikidata.org/"),
        );
        let query = Query::parse(
            "PREFIX schema: <http://schema.org/> SELECT ?s WHERE { ?s a schema:Person }",
        )
        .unwrap();
        let ranking = router.route(&query).unwrap();
        assert!(!ranking.is_empty());
        assert!(ranking.best().is_some());
    }

    #[test]
    fn test_update_stats() {
        let mut router = Router::new();
        router.add_source(DataSource::new("test", "http://example.com/sparql"));
        router.update_source_stats("test", 100, true, 50).unwrap();
        let source = router.get_source("test").unwrap();
        assert_eq!(source.stats.total_queries, 1);
        assert_eq!(source.stats.successful_queries, 1);
    }

    #[cfg(feature = "cache")]
    #[test]
    fn test_cache_result() {
        let mut router = Router::new();
        router.cache_result(12345, "test result".to_string(), "source1".to_string());
        let entry = router.get_cached_result(12345).unwrap();
        assert_eq!(entry.result, "test result");
        assert_eq!(entry.source_id, "source1");
    }

    #[cfg(feature = "cache")]
    #[test]
    fn test_cache_miss() {
        let mut router = Router::new();
        assert!(router.get_cached_result(99999).is_none());
    }

    #[cfg(feature = "cache")]
    #[test]
    fn test_cache_stats() {
        let mut router = Router::new();
        let _ = router.get_cached_result(1);
        router.cache_result(1, "result".to_string(), "s1".to_string());
        let _ = router.get_cached_result(1);
        let stats = router.query_cache_stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
    }

    #[cfg(feature = "cache")]
    #[test]
    fn test_cache_enable_disable() {
        let mut router = Router::new();
        assert!(router.is_cache_enabled());
        router.disable_cache();
        assert!(!router.is_cache_enabled());
        router.cache_result(1, "result".to_string(), "s1".to_string());
        assert!(!router.is_result_cached(1));
        router.enable_cache();
        router.cache_result(1, "result".to_string(), "s1".to_string());
        assert!(router.is_result_cached(1));
    }

    #[cfg(feature = "cache")]
    #[test]
    fn test_cache_clear() {
        let mut router = Router::new();
        router.cache_result(1, "r1".to_string(), "s1".to_string());
        router.cache_result(2, "r2".to_string(), "s2".to_string());
        assert_eq!(router.cached_query_count(), 2);
        router.clear_query_cache();
        assert_eq!(router.cached_query_count(), 0);
    }

    #[cfg(feature = "cache")]
    #[test]
    fn test_cache_invalidate() {
        let mut router = Router::new();
        router.cache_result(1, "r1".to_string(), "s1".to_string());
        assert!(router.is_result_cached(1));
        router.invalidate_cached_result(1);
        assert!(!router.is_result_cached(1));
    }

    #[cfg(feature = "cache")]
    #[test]
    fn test_cache_context() {
        let mut router = Router::new();
        let ctx = crate::context::CombinedContext::new().with_timestamp(12345);
        router.cache_context("provider1", ctx);
        let cached = router.get_cached_context("provider1").unwrap();
        assert_eq!(cached.timestamp, 12345);
    }
}