use std::sync::Arc;
use serde::{Deserialize, Serialize};
use crate::budget::BudgetTracker;
use crate::error::{RlmError, RlmResult};
use crate::session::SessionManager;
use crate::types::{LlmQuery, SessionId};
/// A single LLM query as submitted by a client of the bridge.
///
/// Every field except `prompt` is optional; serde deserializes a missing `Option` field
/// as `None`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct QueryRequest {
    /// Text sent to the model as the user message.
    pub prompt: String,
    /// Requested model identifier, if the caller wants to override the default.
    pub model: Option<String>,
    /// Sampling temperature forwarded to the model.
    pub temperature: Option<f32>,
    /// Upper bound on generated tokens forwarded to the model.
    pub max_tokens: Option<u32>,
}
impl From<QueryRequest> for LlmQuery {
fn from(req: QueryRequest) -> Self {
LlmQuery {
prompt: req.prompt,
model: req.model,
temperature: req.temperature,
max_tokens: req.max_tokens,
}
}
}
/// Result of a single successful query, including the session budget left afterwards.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct QueryResponse {
    /// Raw text returned by the model.
    pub response: String,
    /// Estimated number of tokens charged to the session for this query.
    pub tokens_used: u64,
    /// Wall-clock duration of the query, in milliseconds.
    pub time_ms: u64,
    /// Tokens still available in the session budget after this query.
    pub tokens_remaining: u64,
    /// Time still available in the session budget after this query, in milliseconds.
    pub time_remaining_ms: u64,
}
/// A group of queries to be executed together for one session.
///
/// The number of entries is limited by `LlmBridgeConfig::max_batch_size`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct BatchedQueryRequest {
    /// Individual queries, answered in the same order.
    pub queries: Vec<QueryRequest>,
}
/// Outcome of a batched query: one entry per submitted query plus aggregate totals.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct BatchedQueryResponse {
    /// Per-query results in request order; failures are carried as their error message.
    /// (serde encodes each `Result` externally tagged, e.g. `{"Ok": ...}` / `{"Err": "..."}`.)
    pub responses: Vec<Result<QueryResponse, String>>,
    /// Sum of `tokens_used` over the successful responses.
    pub total_tokens_used: u64,
    /// Wall-clock duration of the whole batch, in milliseconds.
    pub total_time_ms: u64,
}
/// Runtime settings for the LLM bridge.
#[derive(Clone, Debug)]
pub struct LlmBridgeConfig {
    /// Port joined with `bind_addr` to form the listen address.
    pub port: u16,
    /// Host / IP part of the listen address.
    pub bind_addr: String,
    /// Maximum number of queries accepted in one batched request.
    pub max_batch_size: usize,
    /// Per-request timeout in milliseconds.
    ///
    /// NOTE(review): the bridge itself does not read this; presumably enforced by the
    /// serving layer — confirm.
    pub request_timeout_ms: u64,
}

impl Default for LlmBridgeConfig {
    /// Loopback listener on port 8080, batches of at most 10 queries, 30-second timeout.
    fn default() -> Self {
        LlmBridgeConfig {
            bind_addr: String::from("127.0.0.1"),
            port: 8080,
            request_timeout_ms: 30 * 1_000,
            max_batch_size: 10,
        }
    }
}
/// Routes LLM queries for RLM sessions, charging each query against the session's budget.
pub struct LlmBridge {
    /// Listen address, batch limit and timeout settings.
    config: LlmBridgeConfig,
    /// Source of truth for sessions: validates tokens and supplies initial budget status.
    session_manager: Arc<SessionManager>,
    /// Per-session budget trackers, created lazily on first use.
    ///
    /// NOTE(review): nothing in this type removes entries; confirm whether trackers should
    /// be evicted when a session ends.
    budget_trackers: dashmap::DashMap<SessionId, Arc<BudgetTracker>>,
    /// Client used to answer queries; `None` means queries fail with `LlmNotConfigured`.
    #[cfg(feature = "llm")]
    llm_client: Option<Arc<dyn terraphim_service::llm::LlmClient>>,
}
impl LlmBridge {
pub fn new(config: LlmBridgeConfig, session_manager: Arc<SessionManager>) -> Self {
Self {
config,
session_manager,
budget_trackers: dashmap::DashMap::new(),
#[cfg(feature = "llm")]
llm_client: None,
}
}
#[cfg(feature = "llm")]
pub fn with_llm_client(
config: LlmBridgeConfig,
session_manager: Arc<SessionManager>,
client: Arc<dyn terraphim_service::llm::LlmClient>,
) -> Self {
Self {
config,
session_manager,
budget_trackers: dashmap::DashMap::new(),
llm_client: Some(client),
}
}
pub fn validate_token(&self, token: &str) -> RlmResult<SessionId> {
let session_id =
SessionId::from_string(token).map_err(|_| RlmError::InvalidSessionToken {
token: token.to_string(),
})?;
self.session_manager.validate_session(&session_id)?;
Ok(session_id)
}
pub fn get_budget_tracker(&self, session_id: &SessionId) -> Arc<BudgetTracker> {
self.budget_trackers
.entry(*session_id)
.or_insert_with(|| {
let session = self
.session_manager
.get_session(session_id)
.expect("session should exist after validation");
Arc::new(BudgetTracker::from_status(&session.budget_status))
})
.clone()
}
pub async fn query(
&self,
session_id: &SessionId,
request: QueryRequest,
) -> RlmResult<QueryResponse> {
let budget = self.get_budget_tracker(session_id);
budget.check_all()?;
budget.push_recursion()?;
let start = std::time::Instant::now();
#[cfg(feature = "llm")]
let response_text = match &self.llm_client {
Some(client) => {
let chat_opts = terraphim_service::llm::ChatOptions {
max_tokens: request.max_tokens,
temperature: request.temperature,
};
let messages = vec![serde_json::json!({
"role": "user",
"content": request.prompt
})];
client
.chat_completion(messages, chat_opts)
.await
.map_err(|e| RlmError::LlmCallFailed {
message: e.to_string(),
})?
}
None => {
return Err(RlmError::LlmNotConfigured);
}
};
#[cfg(not(feature = "llm"))]
{
let _request = request;
return Err(RlmError::LlmNotConfigured);
}
let estimated_tokens = (request.prompt.len() / 4 + response_text.len() / 4) as u64;
budget.add_tokens(estimated_tokens)?;
let time_ms = start.elapsed().as_millis() as u64;
budget.pop_recursion();
Ok(QueryResponse {
response: response_text,
tokens_used: estimated_tokens,
time_ms,
tokens_remaining: budget.tokens_remaining(),
time_remaining_ms: budget.time_remaining_ms(),
})
}
pub async fn query_batched(
&self,
session_id: &SessionId,
request: BatchedQueryRequest,
) -> RlmResult<BatchedQueryResponse> {
let budget = self.get_budget_tracker(session_id);
budget.check_all()?;
if request.queries.len() > self.config.max_batch_size {
return Err(RlmError::BatchSizeTooLarge {
size: request.queries.len(),
max: self.config.max_batch_size,
});
}
let start = std::time::Instant::now();
let futures: Vec<_> = request
.queries
.into_iter()
.map(|query| {
let session_id = *session_id;
let this = self;
async move { this.query(&session_id, query).await }
})
.collect();
let results = futures::future::join_all(futures).await;
let mut total_tokens = 0u64;
let responses: Vec<Result<QueryResponse, String>> = results
.into_iter()
.map(|r| match r {
Ok(resp) => {
total_tokens += resp.tokens_used;
Ok(resp)
}
Err(e) => Err(e.to_string()),
})
.collect();
let total_time_ms = start.elapsed().as_millis() as u64;
Ok(BatchedQueryResponse {
responses,
total_tokens_used: total_tokens,
total_time_ms,
})
}
pub fn bind_addr(&self) -> String {
format!("{}:{}", self.config.bind_addr, self.config.port)
}
pub fn config(&self) -> &LlmBridgeConfig {
&self.config
}
}
impl BudgetTracker {
pub fn from_status(status: &crate::types::BudgetStatus) -> Self {
use crate::config::RlmConfig;
let config = RlmConfig {
token_budget: status.token_budget,
time_budget_ms: status.time_budget_ms,
max_recursion_depth: status.max_recursion_depth,
..Default::default()
};
let tracker = Self::new(&config);
if status.tokens_used > 0 {
tracker.add_tokens(status.tokens_used).ok();
}
for _ in 0..status.current_recursion_depth {
tracker.push_recursion().ok();
}
tracker
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::RlmConfig;

    /// Builds a request carrying only `prompt`, with every optional knob left unset.
    fn prompt_only(prompt: &str) -> QueryRequest {
        QueryRequest {
            prompt: prompt.to_string(),
            model: None,
            temperature: None,
            max_tokens: None,
        }
    }

    /// Returns `true` when a live LLM round-trip can be attempted. That requires the `llm`
    /// feature to be compiled in and `OPENROUTER_API_KEY` to be set. Otherwise it logs why
    /// `test_name` is being skipped.
    fn live_llm_available(test_name: &str) -> bool {
        if !cfg!(feature = "llm") {
            eprintln!("Skipping {}: `llm` feature not enabled", test_name);
            return false;
        }
        if std::env::var("OPENROUTER_API_KEY").is_err() {
            eprintln!("Skipping {}: OPENROUTER_API_KEY not set", test_name);
            return false;
        }
        true
    }

    /// Creates a bridge backed by a fresh session manager plus one live session.
    ///
    /// With the `llm` feature and `OPENROUTER_API_KEY` set, the bridge gets a real
    /// OpenRouter-backed client. Otherwise it has no client.
    fn create_test_bridge() -> (LlmBridge, SessionId) {
        let config = RlmConfig::default();
        let session_manager = Arc::new(SessionManager::new(config));
        let session = session_manager.create_session().unwrap();
        #[cfg(feature = "llm")]
        let bridge = {
            use ahash::AHashMap;
            use serde_json::Value;
            use terraphim_config::Role;
            use terraphim_types::RelevanceFunction;
            if let Ok(api_key) = std::env::var("OPENROUTER_API_KEY") {
                let mut extra = AHashMap::new();
                extra.insert(
                    "llm_provider".to_string(),
                    Value::String("genai".to_string()),
                );
                extra.insert(
                    "llm_model".to_string(),
                    Value::String("mistralai/mistral-7b-instruct:free".to_string()),
                );
                let role = Role {
                    shortname: None,
                    name: "test".into(),
                    relevance_function: RelevanceFunction::TitleScorer,
                    terraphim_it: false,
                    theme: "default".to_string(),
                    kg: None,
                    haystacks: vec![],
                    llm_enabled: true,
                    llm_api_key: Some(api_key),
                    llm_model: Some("mistralai/mistral-7b-instruct:free".to_string()),
                    llm_auto_summarize: false,
                    llm_chat_enabled: false,
                    llm_chat_system_prompt: None,
                    llm_chat_model: None,
                    llm_context_window: None,
                    extra,
                    llm_router_enabled: false,
                    llm_router_config: None,
                };
                if let Some(client) = terraphim_service::llm::build_llm_from_role(&role) {
                    LlmBridge::with_llm_client(LlmBridgeConfig::default(), session_manager, client)
                } else {
                    LlmBridge::new(LlmBridgeConfig::default(), session_manager)
                }
            } else {
                LlmBridge::new(LlmBridgeConfig::default(), session_manager)
            }
        };
        #[cfg(not(feature = "llm"))]
        let bridge = LlmBridge::new(LlmBridgeConfig::default(), session_manager);
        (bridge, session.id)
    }

    #[test]
    fn test_token_validation() {
        let (bridge, session_id) = create_test_bridge();
        let result = bridge.validate_token(&session_id.to_string());
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), session_id);
        let result = bridge.validate_token("invalid-token");
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_query_without_client_is_not_configured() {
        let session_manager = Arc::new(SessionManager::new(RlmConfig::default()));
        let session = session_manager.create_session().unwrap();
        let bridge = LlmBridge::new(LlmBridgeConfig::default(), session_manager);
        let result = bridge.query(&session.id, prompt_only("Hello")).await;
        assert!(matches!(result, Err(RlmError::LlmNotConfigured)));
    }

    #[tokio::test]
    async fn test_single_query() {
        if !live_llm_available("test_single_query") {
            return;
        }
        let (bridge, session_id) = create_test_bridge();
        let result = bridge.query(&session_id, prompt_only("Hello, world!")).await;
        assert!(result.is_ok(), "query failed: {:?}", result.err());
        let response = result.unwrap();
        assert!(!response.response.is_empty());
        assert!(response.tokens_used > 0);
    }

    #[tokio::test]
    async fn test_batched_query() {
        if !live_llm_available("test_batched_query") {
            return;
        }
        let (bridge, session_id) = create_test_bridge();
        let request = BatchedQueryRequest {
            queries: vec![prompt_only("Query 1"), prompt_only("Query 2")],
        };
        let result = bridge.query_batched(&session_id, request).await;
        assert!(result.is_ok(), "batched query failed: {:?}", result.err());
        let response = result.unwrap();
        assert_eq!(response.responses.len(), 2);
        assert!(response.total_tokens_used > 0);
    }

    #[tokio::test]
    async fn test_batch_size_limit() {
        let config = LlmBridgeConfig {
            max_batch_size: 2,
            ..Default::default()
        };
        let rlm_config = RlmConfig::default();
        let session_manager = Arc::new(SessionManager::new(rlm_config));
        let session = session_manager.create_session().unwrap();
        let bridge = LlmBridge::new(config, session_manager);
        let request = BatchedQueryRequest {
            queries: vec![
                prompt_only("Query 1"),
                prompt_only("Query 2"),
                prompt_only("Query 3"),
            ],
        };
        let result = bridge.query_batched(&session.id, request).await;
        assert!(matches!(result, Err(RlmError::BatchSizeTooLarge { .. })));
    }

    #[test]
    fn test_budget_tracker_from_status() {
        let status = crate::types::BudgetStatus {
            token_budget: 1000,
            tokens_used: 100,
            time_budget_ms: 60_000,
            time_used_ms: 0,
            max_recursion_depth: 5,
            current_recursion_depth: 2,
        };
        let tracker = BudgetTracker::from_status(&status);
        assert_eq!(tracker.tokens_used(), 100);
        assert_eq!(tracker.current_depth(), 2);
    }
}