mod complexity;
mod entity_enhancer;
mod relevance_scorer;
mod retrieval_classifier;
mod router;
pub mod strategies;
mod strategy_selector;
mod summarizer;
mod validator;
pub use complexity::{ComplexityResult, ComplexityScorer, ComplexityScorerBuilder};
pub use entity_enhancer::{
EnhancedEntity, EnhancedRelationship, EnhancementResult, EntityEnhancer, EntityEnhancerBuilder,
RelationType, SemanticEntityType,
};
pub use relevance_scorer::{RelevanceResult, RelevanceScorer, RelevanceScorerBuilder};
pub use retrieval_classifier::{
ClassificationResult, RetrievalClassifier, RetrievalClassifierBuilder,
RetrievalNeed as LocalRetrievalNeed,
};
pub use router::{LocalRouter, LocalRouterBuilder, RouteResult};
pub use strategies::{
ChainOfThoughtStrategy, ReActStrategy, ReasoningStrategy, ReflexionStrategy, StrategyPreset,
StrategyStep, TreeOfThoughtsStrategy,
};
pub use strategy_selector::{
RecommendedStrategy, StrategyResult, StrategySelector, StrategySelectorBuilder, TaskType,
};
pub use summarizer::{
ExtractedFact, FactCategory, LocalSummarizer, LocalSummarizerBuilder, SummarizationResult,
};
pub use validator::{LocalValidator, LocalValidatorBuilder, ValidationResult};
use std::time::Instant;
use tracing::{info, warn};
/// Feature flags and model selection for the local (small-model) inference tasks.
///
/// All per-task flags are `false` in the [`Default`] configuration. Use a preset such as
/// [`LocalInferenceConfig::tier1_enabled`] or struct-update syntax to switch tasks on.
/// Tier 1 covers routing, validation and complexity scoring. Tier 2 covers summarization,
/// retrieval gating, relevance scoring, strategy selection and entity enhancement.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LocalInferenceConfig {
    /// Enable the routing task (tier 1).
    pub routing_enabled: bool,
    /// Enable the validation task (tier 1).
    pub validation_enabled: bool,
    /// Enable the complexity-scoring task (tier 1).
    pub complexity_enabled: bool,
    /// Enable the summarization task (tier 2).
    pub summarization_enabled: bool,
    /// Enable retrieval gating (tier 2).
    pub retrieval_gating_enabled: bool,
    /// Enable relevance scoring (tier 2).
    pub relevance_scoring_enabled: bool,
    /// Enable reasoning-strategy selection (tier 2).
    pub strategy_selection_enabled: bool,
    /// Enable entity enhancement (tier 2).
    pub entity_enhancement_enabled: bool,
    /// Model for routing; defaults to `lfm2-350m`.
    pub routing_model: Option<String>,
    /// Model for validation; defaults to `lfm2-350m`.
    pub validation_model: Option<String>,
    /// Model for complexity scoring; defaults to `lfm2-350m`.
    pub complexity_model: Option<String>,
    /// Model for summarization; defaults to the larger `lfm2-1.2b`.
    pub summarization_model: Option<String>,
    /// Model for retrieval gating; defaults to `lfm2-350m`.
    pub retrieval_model: Option<String>,
    /// Model for relevance scoring; defaults to `lfm2-350m`.
    pub relevance_model: Option<String>,
    /// Model for strategy selection; defaults to the larger `lfm2-1.2b`.
    pub strategy_model: Option<String>,
    /// Model for entity enhancement; defaults to `lfm2-350m`.
    pub entity_model: Option<String>,
    /// Whether inference calls should be logged; defaults to `true`.
    ///
    /// NOTE(review): the module-level `log_inference` function does not read this flag.
    /// Callers must check it themselves before logging.
    pub log_inference: bool,
}
impl Default for LocalInferenceConfig {
    /// Builds a configuration with every local-inference task switched off.
    ///
    /// Lightweight tasks are pre-assigned `lfm2-350m`. Summarization and strategy selection
    /// are pre-assigned the larger `lfm2-1.2b`. Inference logging is on.
    fn default() -> Self {
        const SMALL_MODEL: &str = "lfm2-350m";
        const LARGE_MODEL: &str = "lfm2-1.2b";
        let model = |name: &str| Some(name.to_string());

        Self {
            routing_enabled: false,
            validation_enabled: false,
            complexity_enabled: false,
            summarization_enabled: false,
            retrieval_gating_enabled: false,
            relevance_scoring_enabled: false,
            strategy_selection_enabled: false,
            entity_enhancement_enabled: false,
            routing_model: model(SMALL_MODEL),
            validation_model: model(SMALL_MODEL),
            complexity_model: model(SMALL_MODEL),
            summarization_model: model(LARGE_MODEL),
            retrieval_model: model(SMALL_MODEL),
            relevance_model: model(SMALL_MODEL),
            strategy_model: model(LARGE_MODEL),
            entity_model: model(SMALL_MODEL),
            log_inference: true,
        }
    }
}
impl LocalInferenceConfig {
    /// Tier 1 preset: routing, validation and complexity scoring on; every other task off.
    ///
    /// Model names and the logging flag are taken from [`Default`].
    pub fn tier1_enabled() -> Self {
        Self {
            routing_enabled: true,
            validation_enabled: true,
            complexity_enabled: true,
            ..Self::default()
        }
    }

    /// Tier 2 preset: summarization, retrieval gating, relevance scoring, strategy
    /// selection and entity enhancement on; tier 1 tasks stay off.
    pub fn tier2_enabled() -> Self {
        Self {
            summarization_enabled: true,
            retrieval_gating_enabled: true,
            relevance_scoring_enabled: true,
            strategy_selection_enabled: true,
            entity_enhancement_enabled: true,
            ..Self::default()
        }
    }

    /// Every local-inference task switched on (tier 1 and tier 2 combined).
    pub fn all_enabled() -> Self {
        // The tier 2 preset already enables the five tier-2 tasks; add the tier-1 flags on top.
        Self {
            routing_enabled: true,
            validation_enabled: true,
            complexity_enabled: true,
            ..Self::tier2_enabled()
        }
    }

    /// Only the routing task enabled.
    pub fn routing_only() -> Self {
        Self { routing_enabled: true, ..Self::default() }
    }

    /// Only the validation task enabled.
    pub fn validation_only() -> Self {
        Self { validation_enabled: true, ..Self::default() }
    }

    /// Only the summarization task enabled.
    pub fn summarization_only() -> Self {
        Self { summarization_enabled: true, ..Self::default() }
    }
}
/// Emits a tracing event under the `local_llm` target for one finished local inference call.
///
/// `task`, `model` and `latency_ms` are attached as structured fields. A successful call is
/// logged at `INFO`. A failed call is logged at `WARN` with a message saying the caller is
/// falling back to the pattern-based path.
pub fn log_inference(task: &str, model: &str, latency_ms: u64, success: bool) {
    match success {
        true => info!(
            target: "local_llm",
            task,
            model,
            latency_ms,
            "Local inference completed"
        ),
        false => warn!(
            target: "local_llm",
            task,
            model,
            latency_ms,
            "Local inference failed, falling back to pattern-based"
        ),
    }
}
/// Stopwatch for a single local inference call.
///
/// It records the start instant together with the task and model labels. Those labels are
/// reported alongside the elapsed time when the timer is finished.
#[derive(Debug)]
pub struct InferenceTimer {
    /// Instant captured when the timer was created.
    start: Instant,
    /// Task label forwarded to [`log_inference`].
    task: String,
    /// Model identifier forwarded to [`log_inference`].
    model: String,
}
impl InferenceTimer {
    /// Starts timing an inference call for `task` running on `model`.
    pub fn new(task: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            start: Instant::now(),
            task: task.into(),
            model: model.into(),
        }
    }

    /// Consumes the timer and reports the elapsed time through [`log_inference`].
    ///
    /// `success` selects between the completion (`INFO`) and fallback (`WARN`) log events.
    pub fn finish(self, success: bool) {
        let latency_ms = self.elapsed_ms();
        log_inference(&self.task, &self.model, latency_ms, success);
    }

    /// Milliseconds elapsed since [`InferenceTimer::new`].
    ///
    /// The result saturates at `u64::MAX`. It does not silently truncate the `u128` returned by
    /// [`Duration::as_millis`](std::time::Duration::as_millis).
    pub fn elapsed_ms(&self) -> u64 {
        let millis = self.start.elapsed().as_millis();
        // A bare `as u64` would drop the high bits on overflow; clamp first so the cast is lossless.
        millis.min(u128::from(u64::MAX)) as u64
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts every per-task flag of `config` equals `expected`.
    ///
    /// Order: routing, validation, complexity, summarization, retrieval gating,
    /// relevance scoring, strategy selection, entity enhancement.
    fn assert_flags(config: &LocalInferenceConfig, expected: [bool; 8]) {
        let actual = [
            config.routing_enabled,
            config.validation_enabled,
            config.complexity_enabled,
            config.summarization_enabled,
            config.retrieval_gating_enabled,
            config.relevance_scoring_enabled,
            config.strategy_selection_enabled,
            config.entity_enhancement_enabled,
        ];
        assert_eq!(actual, expected);
    }

    /// Asserts the model assignments produced by `Default` are present on `config`.
    fn assert_default_models(config: &LocalInferenceConfig) {
        let small = Some("lfm2-350m".to_string());
        let large = Some("lfm2-1.2b".to_string());
        assert_eq!(config.routing_model, small);
        assert_eq!(config.validation_model, small);
        assert_eq!(config.complexity_model, small);
        assert_eq!(config.summarization_model, large);
        assert_eq!(config.retrieval_model, small);
        assert_eq!(config.relevance_model, small);
        assert_eq!(config.strategy_model, large);
        assert_eq!(config.entity_model, small);
        assert!(config.log_inference);
    }

    #[test]
    fn test_config_default() {
        let config = LocalInferenceConfig::default();
        assert_flags(&config, [false; 8]);
        assert_default_models(&config);
    }

    #[test]
    fn test_config_tier1_enabled() {
        let config = LocalInferenceConfig::tier1_enabled();
        assert_flags(&config, [true, true, true, false, false, false, false, false]);
        assert_default_models(&config);
    }

    #[test]
    fn test_config_tier2_enabled() {
        let config = LocalInferenceConfig::tier2_enabled();
        assert_flags(&config, [false, false, false, true, true, true, true, true]);
        assert_default_models(&config);
    }

    #[test]
    fn test_config_all_enabled() {
        let config = LocalInferenceConfig::all_enabled();
        assert_flags(&config, [true; 8]);
        assert_default_models(&config);
    }

    #[test]
    fn test_config_routing_only() {
        let config = LocalInferenceConfig::routing_only();
        assert_flags(&config, [true, false, false, false, false, false, false, false]);
    }

    #[test]
    fn test_config_validation_only() {
        let config = LocalInferenceConfig::validation_only();
        assert_flags(&config, [false, true, false, false, false, false, false, false]);
    }

    #[test]
    fn test_config_summarization_only() {
        let config = LocalInferenceConfig::summarization_only();
        assert_flags(&config, [false, false, false, true, false, false, false, false]);
        assert_eq!(config.summarization_model, Some("lfm2-1.2b".to_string()));
    }

    #[test]
    fn test_inference_timer() {
        let timer = InferenceTimer::new("test_task", "test_model");
        std::thread::sleep(std::time::Duration::from_millis(10));
        let first = timer.elapsed_ms();
        assert!(first >= 10);
        // Elapsed time is measured from a monotonic Instant, so it never goes backwards.
        assert!(timer.elapsed_ms() >= first);
    }
}