use std::collections::HashMap;
use std::sync::Arc;
use cognis_core::embeddings::Embeddings;
use cognis_core::error::{CognisError, Result};
use cognis_core::language_models::chat_model::BaseChatModel;
use cognis_core::messages::{HumanMessage, Message};
use cognis_core::vectorstores::base::cosine_similarity;
/// A named destination that a [`SemanticRouter`] can select for a query.
///
/// The `description` is what gets embedded and compared against incoming
/// queries; `prompt_template` optionally supplies a route-specific prompt used
/// by [`RouterChain`], in which `{query}` is replaced by the user's query.
#[derive(Debug, Clone)]
pub struct Route {
    /// Identifier reported back to callers (e.g. as `RouterResult::route_name`).
    pub name: String,
    /// Text that is embedded and matched against queries.
    pub description: String,
    /// Optional prompt template for this route; `None` until set via
    /// [`Route::with_prompt_template`].
    pub prompt_template: Option<String>,
}

impl Route {
    /// Creates a route with the given name and description and no prompt template.
    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
        let name: String = name.into();
        let description: String = description.into();
        Route {
            name,
            description,
            prompt_template: None,
        }
    }

    /// Builder-style setter that attaches (or replaces) this route's prompt template.
    pub fn with_prompt_template(self, template: impl Into<String>) -> Self {
        Route {
            prompt_template: Some(template.into()),
            ..self
        }
    }
}
/// Selects the [`Route`] whose description embedding is most similar, by
/// cosine similarity, to the embedding of an incoming query.
pub struct SemanticRouter {
    embeddings: Arc<dyn Embeddings>,
    routes: Vec<Route>,
    /// Embedding of each route's description: `route_embeddings[i]` belongs to
    /// `routes[i]`. Equal lengths are enforced in [`SemanticRouter::new`].
    route_embeddings: Vec<Vec<f32>>,
    /// Name of the route returned by [`SemanticRouter::find_default_route`].
    ///
    /// NOTE(review): routing itself never consults this field; `route` and
    /// `route_with_score` always return the best-scoring route.
    pub default_route: Option<String>,
}

impl SemanticRouter {
    /// Builds a router, embedding every route description up front.
    ///
    /// # Errors
    ///
    /// Propagates any error from `embed_documents`, and returns
    /// `CognisError::Other` if the provider returns a different number of
    /// vectors than there are routes.
    pub async fn new(embeddings: Arc<dyn Embeddings>, routes: Vec<Route>) -> Result<Self> {
        let descriptions: Vec<String> = routes.iter().map(|r| r.description.clone()).collect();
        let route_embeddings = embeddings.embed_documents(descriptions).await?;
        // Routing pairs routes with embeddings by position. A count mismatch
        // would either make trailing routes unreachable or let a route index
        // go out of bounds, so reject it here instead of failing later.
        if route_embeddings.len() != routes.len() {
            return Err(CognisError::Other(
                format!(
                    "Embedding provider returned {} vectors for {} routes",
                    route_embeddings.len(),
                    routes.len()
                )
                .into(),
            ));
        }
        Ok(Self {
            embeddings,
            routes,
            route_embeddings,
            default_route: None,
        })
    }

    /// Builder-style setter for [`SemanticRouter::default_route`].
    pub fn with_default_route(mut self, name: impl Into<String>) -> Self {
        self.default_route = Some(name.into());
        self
    }

    /// Returns the best-matching route for `query`, discarding the score.
    ///
    /// # Errors
    ///
    /// Same as [`SemanticRouter::route_with_score`].
    pub async fn route(&self, query: &str) -> Result<&Route> {
        let (route, _score) = self.route_with_score(query).await?;
        Ok(route)
    }

    /// Embeds `query` and returns the route with the highest cosine similarity,
    /// together with that similarity.
    ///
    /// Ties go to the route listed first. NaN similarities never win; if every
    /// similarity is NaN, the first route is returned with `f32::NEG_INFINITY`.
    ///
    /// # Errors
    ///
    /// Returns `CognisError::Other` if the router has no routes, and
    /// propagates any error from `embed_query`.
    pub async fn route_with_score(&self, query: &str) -> Result<(&Route, f32)> {
        let first = match self.routes.first() {
            Some(route) => route,
            None => return Err(CognisError::Other("No routes defined".into())),
        };
        let query_embedding = self.embeddings.embed_query(query).await?;
        // `fold` with a strict `>` (rather than `Iterator::max_by`, which keeps
        // the *last* maximum) preserves first-wins tie-breaking and skips NaN.
        let best = self
            .routes
            .iter()
            .zip(&self.route_embeddings)
            .map(|(route, route_emb)| (route, cosine_similarity(&query_embedding, route_emb)))
            .fold((first, f32::NEG_INFINITY), |best, candidate| {
                if candidate.1 > best.1 {
                    candidate
                } else {
                    best
                }
            });
        Ok(best)
    }

    /// Returns all routes in the order they were supplied to [`SemanticRouter::new`].
    pub fn routes(&self) -> &[Route] {
        &self.routes
    }

    /// Looks up the route whose name equals `default_route`.
    ///
    /// Returns `None` when no default is configured or no route has that name.
    pub fn find_default_route(&self) -> Option<&Route> {
        self.default_route
            .as_ref()
            .and_then(|name| self.routes.iter().find(|r| r.name == *name))
    }
}
/// Outcome of `RouterChain::call`: the chat model's answer plus routing metadata.
#[derive(Debug, Clone)]
pub struct RouterResult {
/// Text content of the chat model's reply.
pub answer: String,
/// Name of the route the query was dispatched to.
pub route_name: String,
/// Cosine similarity between the query and the selected route's description
/// embedding; may be `f32::NEG_INFINITY` if no similarity could be compared.
pub confidence: f32,
}
/// Couples a [`SemanticRouter`] with a chat model: each query is routed, the
/// selected prompt template is filled in, and the model is invoked once.
pub struct RouterChain {
    router: SemanticRouter,
    llm: Arc<dyn BaseChatModel>,
    /// Chain-level template overrides, keyed by route name.
    route_prompts: HashMap<String, String>,
    /// Fallback template used when neither an override nor the route's own
    /// template is available.
    default_prompt: String,
}

impl RouterChain {
    /// Creates a chain with no per-route overrides and a generic default prompt.
    pub fn new(router: SemanticRouter, llm: Arc<dyn BaseChatModel>) -> Self {
        let default_prompt = String::from("Answer the following question: {query}");
        RouterChain {
            router,
            llm,
            route_prompts: HashMap::new(),
            default_prompt,
        }
    }

    /// Registers (or replaces) the template used for the route named `route_name`.
    ///
    /// This takes precedence over the route's own `prompt_template`.
    pub fn with_route_prompt(
        mut self,
        route_name: impl Into<String>,
        template: impl Into<String>,
    ) -> Self {
        let (key, value) = (route_name.into(), template.into());
        self.route_prompts.insert(key, value);
        self
    }

    /// Replaces the fallback template.
    pub fn with_default_prompt(mut self, template: impl Into<String>) -> Self {
        let replacement: String = template.into();
        self.default_prompt = replacement;
        self
    }

    /// Routes `query`, builds the prompt, and returns the model's answer with
    /// the chosen route name and its similarity score.
    ///
    /// Template precedence: chain-level override for the route, then the
    /// route's own `prompt_template`, then the chain default. Every `{query}`
    /// in the template is replaced by `query`.
    ///
    /// # Errors
    ///
    /// Propagates errors from routing and from the chat model invocation.
    pub async fn call(&self, query: &str) -> Result<RouterResult> {
        let (route, confidence) = self.router.route_with_score(query).await?;
        let template: &String = match self.route_prompts.get(&route.name) {
            Some(custom) => custom,
            None => route
                .prompt_template
                .as_ref()
                .unwrap_or(&self.default_prompt),
        };
        let prompt = template.replace("{query}", query);
        let messages = vec![Message::Human(HumanMessage::new(&prompt))];
        let reply = self.llm.invoke_messages(&messages, None).await?;
        Ok(RouterResult {
            answer: reply.base.content.text(),
            route_name: route.name.clone(),
            confidence,
        })
    }
}
// Unit tests for `SemanticRouter` and `RouterChain`, driven by
// `DeterministicFakeEmbedding` and the canned-response `FakeListChatModel`.
#[cfg(test)]
mod tests {
use super::*;
use cognis_core::embeddings_fake::DeterministicFakeEmbedding;
use cognis_core::language_models::fake::FakeListChatModel;
// Fake embedding provider; 128 is presumably the vector dimension — confirm.
fn fake_embeddings() -> Arc<dyn Embeddings> {
Arc::new(DeterministicFakeEmbedding::new(128))
}
// Fake chat model seeded with the given canned responses.
fn fake_model(responses: Vec<&str>) -> Arc<dyn BaseChatModel> {
Arc::new(FakeListChatModel::new(
responses.into_iter().map(String::from).collect(),
))
}
// Querying with a route's exact description should select that route.
// Relies on the fake embedding mapping identical text to identical vectors.
#[tokio::test]
async fn test_route_to_correct_route_based_on_similarity() {
let routes = vec![
Route::new("math", "mathematics calculations arithmetic numbers"),
Route::new("history", "historical events dates civilizations wars"),
Route::new("science", "physics chemistry biology experiments"),
];
let router = SemanticRouter::new(fake_embeddings(), routes)
.await
.unwrap();
let route = router
.route("mathematics calculations arithmetic numbers")
.await
.unwrap();
assert_eq!(route.name, "math");
let route = router
.route("historical events dates civilizations wars")
.await
.unwrap();
assert_eq!(route.name, "history");
let route = router
.route("physics chemistry biology experiments")
.await
.unwrap();
assert_eq!(route.name, "science");
}
// An exact match scores ~1.0; an unrelated query still yields a cosine
// similarity within [-1, 1].
#[tokio::test]
async fn test_route_with_score_returns_valid_confidence() {
let routes = vec![
Route::new("greeting", "hello hi greetings welcome"),
Route::new("farewell", "goodbye bye see you later"),
];
let router = SemanticRouter::new(fake_embeddings(), routes)
.await
.unwrap();
let (route, score) = router
.route_with_score("hello hi greetings welcome")
.await
.unwrap();
assert_eq!(route.name, "greeting");
assert!(
(score - 1.0).abs() < 1e-5,
"Exact match should have score ~1.0, got {score}"
);
let (_route, score) = router.route_with_score("some random query").await.unwrap();
assert!(
score >= -1.0 && score <= 1.0,
"Score should be in [-1,1], got {score}"
);
}
// The chain reports the matched route name and returns the model's answer.
// NOTE(review): the fake model's output does not depend on the prompt, so this
// does not actually verify which template was used.
#[tokio::test]
async fn test_router_chain_uses_correct_prompt_for_matched_route() {
let routes = vec![
Route::new("math", "mathematics calculations arithmetic numbers"),
Route::new("history", "historical events dates civilizations wars"),
];
let router = SemanticRouter::new(fake_embeddings(), routes)
.await
.unwrap();
let chain = RouterChain::new(router, fake_model(vec!["42"]))
.with_route_prompt("math", "Solve this math problem: {query}")
.with_route_prompt("history", "Answer this history question: {query}");
let result = chain
.call("mathematics calculations arithmetic numbers")
.await
.unwrap();
assert_eq!(result.route_name, "math");
assert_eq!(result.answer, "42");
}
// Checks `find_default_route` lookup and that the chain falls back to its
// default prompt when no template is configured.
// NOTE(review): despite the name, the query below is the exact description of
// the "default" route, so this exercises ordinary best-match routing, not a
// fallback for poor matches (routing never consults `default_route`).
#[tokio::test]
async fn test_default_route_when_no_good_match() {
let routes = vec![
Route::new("default", "general questions anything else"),
Route::new("cooking", "recipes food ingredients kitchen"),
];
let router = SemanticRouter::new(fake_embeddings(), routes)
.await
.unwrap()
.with_default_route("default");
let default = router.find_default_route();
assert!(default.is_some());
assert_eq!(default.unwrap().name, "default");
let chain = RouterChain::new(router, fake_model(vec!["I can help with that"]));
let result = chain.call("general questions anything else").await.unwrap();
assert_eq!(result.route_name, "default");
assert_eq!(result.answer, "I can help with that");
}
// Exact-description queries resolve correctly across five routes.
#[tokio::test]
async fn test_multiple_routes_with_different_descriptions() {
let routes = vec![
Route::new("weather", "weather forecast temperature rain sunny cloudy"),
Route::new("sports", "football basketball soccer tennis athletes game"),
Route::new("music", "songs albums artists concerts genres rhythm"),
Route::new("travel", "flights hotels destinations tourism vacation"),
Route::new("tech", "programming software computers code algorithms"),
];
let router = SemanticRouter::new(fake_embeddings(), routes)
.await
.unwrap();
let route = router
.route("weather forecast temperature rain sunny cloudy")
.await
.unwrap();
assert_eq!(route.name, "weather");
let route = router
.route("football basketball soccer tennis athletes game")
.await
.unwrap();
assert_eq!(route.name, "sports");
let route = router
.route("songs albums artists concerts genres rhythm")
.await
.unwrap();
assert_eq!(route.name, "music");
let route = router
.route("flights hotels destinations tourism vacation")
.await
.unwrap();
assert_eq!(route.name, "travel");
let route = router
.route("programming software computers code algorithms")
.await
.unwrap();
assert_eq!(route.name, "tech");
}
// `RouterResult` carries route name, answer, and a near-1.0 confidence for
// an exact-description query.
#[tokio::test]
async fn test_router_result_contains_correct_metadata() {
let routes = vec![Route::new(
"support",
"customer support help troubleshooting issues",
)];
let router = SemanticRouter::new(fake_embeddings(), routes)
.await
.unwrap();
let chain = RouterChain::new(router, fake_model(vec!["Let me help you"]))
.with_route_prompt("support", "Help the customer: {query}");
let result = chain
.call("customer support help troubleshooting issues")
.await
.unwrap();
assert_eq!(result.route_name, "support");
assert_eq!(result.answer, "Let me help you");
assert!(
result.confidence > 0.99,
"Exact match confidence should be near 1.0, got {}",
result.confidence
);
}
// A route-level `prompt_template` is used when the chain has no override.
// NOTE(review): as above, the fake model's reply does not reveal the prompt.
#[tokio::test]
async fn test_route_with_custom_prompt_template_on_route() {
let routes = vec![
Route::new("translate", "translation language convert words")
.with_prompt_template("Translate the following: {query}"),
];
let router = SemanticRouter::new(fake_embeddings(), routes)
.await
.unwrap();
let chain = RouterChain::new(router, fake_model(vec!["Translated text"]));
let result = chain
.call("translation language convert words")
.await
.unwrap();
assert_eq!(result.route_name, "translate");
assert_eq!(result.answer, "Translated text");
}
}