use crate::providers::{Provider, StreamResult};
use crate::role_switcher::{
DEFAULT_LONG_CONTEXT_THRESHOLD, RoleSignals, decide_role, resolve_role_to_model,
};
use crate::roles::RoleRegistry;
use crate::{Context, Model, StreamOptions, ThinkingLevel};
use parking_lot::RwLock;
use std::pin::Pin;
use std::sync::Arc;
/// A [`Provider`] wrapper that may redirect a request to a role-specific
/// model, delegating to a wrapped default provider otherwise.
pub struct RoleRoutingProvider {
/// Provider used when the registry is empty, when no distinct role model is
/// resolved, when no provider is found for the role model, or when the
/// role model's provider returns an error.
default: Arc<dyn Provider>,
/// Shared role registry; read on every `stream` call and mutated through
/// `update_registry`.
registry: Arc<RwLock<RoleRegistry>>,
}
impl RoleRoutingProvider {
    /// Builds a router around `default`, consulting `registry` for role models.
    #[must_use]
    pub fn new(default: Arc<dyn Provider>, registry: Arc<RwLock<RoleRegistry>>) -> Self {
        Self { default, registry }
    }

    /// Runs `f` against the registry while holding its write lock.
    ///
    /// The lock is released as soon as `f` returns.
    pub fn update_registry(&self, f: impl FnOnce(&mut RoleRegistry)) {
        let mut guard = self.registry.write();
        f(&mut *guard);
    }

    /// Derives the routing signals for one request.
    ///
    /// * `thinking_enabled` is true only when `options` carries a thinking
    ///   level other than [`ThinkingLevel::Off`].
    /// * `estimated_tokens` and `is_trivial` are computed from `context`.
    /// * `explicit_override` and `current_tool` are always `None`, and the
    ///   long-context threshold is always [`DEFAULT_LONG_CONTEXT_THRESHOLD`].
    #[must_use]
    pub fn signals_from_request(
        context: &Context,
        options: &Option<StreamOptions>,
    ) -> RoleSignals<'static> {
        let requested_level = options.as_ref().and_then(|o| o.thinking_level);
        let thinking_enabled =
            matches!(requested_level, Some(level) if level != ThinkingLevel::Off);
        let estimated_tokens = estimate_tokens(context);
        let trivial = is_trivial(context);

        RoleSignals {
            explicit_override: None,
            current_tool: None,
            thinking_enabled,
            estimated_tokens,
            long_context_threshold: DEFAULT_LONG_CONTEXT_THRESHOLD,
            is_trivial: trivial,
        }
    }
}
/// Rough token estimate for the whole context.
///
/// Each message contributes its text length in bytes divided by four
/// (integer division, applied per message); messages without text count as
/// zero.
fn estimate_tokens(context: &Context) -> usize {
    let mut total = 0;
    for message in context.messages.iter() {
        let text = message.text_content().unwrap_or_default();
        total += text.len() / 4;
    }
    total
}
/// Reports whether the most recent message looks trivial.
///
/// A message is trivial when its text is shorter than 40 bytes and contains
/// no triple-backtick code fence. A context with no messages, or a last
/// message without text, is treated as empty text and therefore trivial.
fn is_trivial(context: &Context) -> bool {
    let Some(message) = context.messages.last() else {
        // No messages at all: equivalent to an empty last message.
        return true;
    };
    let text = message.text_content().unwrap_or_default();
    if text.contains("```") {
        return false;
    }
    text.len() < 40
}
impl Provider for RoleRoutingProvider {
    /// Streams a completion, preferring the model selected for the request's
    /// role and falling back to the default provider.
    ///
    /// Resolution order:
    /// 1. If the registry is empty, no model is resolved for the role, or the
    ///    resolved model has the same provider and id as `model`, the default
    ///    provider handles the request with `model`.
    /// 2. Otherwise the role model's provider is looked up via
    ///    `crate::get_provider_arc`; if none is found, the default provider
    ///    handles the request with `model`.
    /// 3. If the role model's provider returns an error, a warning is logged
    ///    and the request is sent to the default provider with `model`.
    fn stream<'a>(
        &'a self,
        model: &'a Model,
        context: &'a Context,
        options: Option<StreamOptions>,
    ) -> Pin<Box<dyn Future<Output = StreamResult> + Send + 'a>> {
        let default = Arc::clone(&self.default);
        let registry = Arc::clone(&self.registry);
        Box::pin(async move {
            // The read guard lives only inside this block, so the lock is
            // released before the first `.await` below.
            let role_model = {
                let reg = registry.read();
                if reg.is_empty() {
                    None
                } else {
                    let signals = Self::signals_from_request(context, &options);
                    let role = decide_role(&signals);
                    // Routing to the model already requested would be a no-op,
                    // so treat that case as "no role model".
                    resolve_role_to_model(role, &reg)
                        .filter(|m| m.provider != model.provider || m.id != model.id)
                }
            };

            let Some(role_model) = role_model else {
                return default.stream(model, context, options).await;
            };
            let Some(provider) = crate::get_provider_arc(&role_model.provider) else {
                return default.stream(model, context, options).await;
            };

            // `options` is cloned because it is still needed for the fallback.
            match provider.stream(&role_model, context, options.clone()).await {
                Ok(stream) => Ok(stream),
                Err(err) => {
                    tracing::warn!(
                        target: "role-router",
                        error = %err,
                        "role-model provider failed; falling back to default model"
                    );
                    default.stream(model, context, options).await
                }
            }
        })
    }

    /// Reports the wrapped default provider's name.
    fn name(&self) -> &str {
        self.default.name()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Context, Message, StreamOptions, ThinkingLevel};

    /// Builds a context whose only message is a user message with `text`.
    fn ctx_with_last(text: &str) -> Context {
        Context {
            messages: vec![Message::user(text)],
            ..Context::default()
        }
    }

    #[test]
    fn signals_thinking_from_options() {
        let ctx = ctx_with_last("explain");
        let opts = StreamOptions::default().thinking_level(ThinkingLevel::High);
        let s = RoleRoutingProvider::signals_from_request(&ctx, &Some(opts));
        assert!(s.thinking_enabled);
    }

    #[test]
    fn signals_thinking_off_is_disabled() {
        let ctx = ctx_with_last("explain");
        let opts = StreamOptions::default().thinking_level(ThinkingLevel::Off);
        let s = RoleRoutingProvider::signals_from_request(&ctx, &Some(opts));
        assert!(!s.thinking_enabled);
    }

    #[test]
    fn signals_no_options_is_not_thinking() {
        let ctx = ctx_with_last("explain");
        let s = RoleRoutingProvider::signals_from_request(&ctx, &None);
        assert!(!s.thinking_enabled);
    }

    #[test]
    fn signals_long_context_exceeds_threshold() {
        // Size the message from the constant so the test keeps holding if the
        // threshold is retuned: (threshold + 1) * 4 bytes estimates to
        // exactly threshold + 1 tokens.
        let big = "x".repeat((DEFAULT_LONG_CONTEXT_THRESHOLD + 1) * 4);
        let ctx = ctx_with_last(&big);
        let s = RoleRoutingProvider::signals_from_request(&ctx, &None);
        assert!(
            s.estimated_tokens > DEFAULT_LONG_CONTEXT_THRESHOLD,
            "estimated {} should exceed {}",
            s.estimated_tokens,
            DEFAULT_LONG_CONTEXT_THRESHOLD
        );
    }

    #[test]
    fn signals_short_message_is_trivial() {
        let ctx = ctx_with_last("hi");
        let s = RoleRoutingProvider::signals_from_request(&ctx, &None);
        assert!(s.is_trivial);
    }

    #[test]
    fn signals_code_fence_is_not_trivial() {
        let ctx = ctx_with_last("```\ncode\n```");
        let s = RoleRoutingProvider::signals_from_request(&ctx, &None);
        assert!(!s.is_trivial);
    }

    #[test]
    fn estimate_tokens_scales_with_text() {
        let small = ctx_with_last("hi");
        let large = ctx_with_last(&"x".repeat(4_000));
        assert!(estimate_tokens(&large) > estimate_tokens(&small));
    }

    // NOTE: this only checks the precondition for the pass-through path (a
    // fresh registry reports itself empty); it does not drive `stream`.
    #[test]
    fn empty_registry_passes_through() {
        let r = RoleRegistry::new();
        assert!(r.is_empty());
    }
}