use crate::t;
use futures::future::ok;
use langhub::LLMClient;
use langhub::llms::LLMResult;
use langhub::types::{ChatMessage, LangHubError};
use serde_json::Value;
use std::collections::HashMap;
use std::fmt;
use hippox_drivers::{
DriverCallback, DriverContext, generate_driver_registry_table_json_str, get_driver_by_name,
has_driver, list_drivers_names,
};
/// Picks and runs registered drivers for user requests, using an LLM both to
/// choose the driver and to answer directly when no driver applies.
#[derive(Clone)]
pub struct DriverScheduler {
    /// Client used for driver selection, fallback chat and raw LLM calls.
    llm: LLMClient,
}

impl fmt::Debug for DriverScheduler {
    /// Formats the scheduler without printing the LLM client itself; the
    /// `llm` field is shown as the placeholder string `"<LLMClient>"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const LLM_PLACEHOLDER: &str = "<LLMClient>";
        let mut builder = f.debug_struct("DriverScheduler");
        builder.field("llm", &LLM_PLACEHOLDER);
        builder.finish()
    }
}
impl DriverScheduler {
    /// Creates a scheduler that uses `llm` for driver selection and fallback chat.
    pub fn new(llm: LLMClient) -> Self {
        Self { llm }
    }

    /// Returns the driver registry (from `generate_driver_registry_table_json_str`)
    /// under a Markdown heading, ready to embed in an LLM prompt.
    pub fn get_drivers_prompt(&self) -> String {
        let registry_json = generate_driver_registry_table_json_str();
        format!("## Available Drivers (JSON Registry)\n{}", registry_json)
    }

    /// Asks the LLM which registered driver should handle `user_input`.
    ///
    /// Returns `Ok(None)` without calling the LLM when no drivers are registered.
    /// Otherwise the reply is trimmed of whitespace and of surrounding quote,
    /// backtick, asterisk and period characters, because models often decorate
    /// a bare name that way. An empty reply, `none` (any ASCII case), or a name
    /// that `has_driver` does not recognise all yield `Ok(None)`.
    ///
    /// # Errors
    /// Propagates any error returned by the LLM call.
    pub async fn select_driver(&self, user_input: &str) -> anyhow::Result<Option<String>> {
        if list_drivers_names().is_empty() {
            return Ok(None);
        }
        let drivers_prompt = self.get_drivers_prompt();
        let select_prompt = format!(
            "{}\n\nAvailable drivers:\n{}\n\nUser input: {}\n\nRespond with ONLY the driver name, or 'none' if no driver matches.\n",
            t!("prompt.select_driver_header"),
            drivers_prompt,
            user_input
        );
        let reply = self.llm.generate(&select_prompt).await?.text;
        // Strip markdown/quote decoration such as `name`, "name", **name** or "name.".
        let candidate = reply
            .trim()
            .trim_matches(|c: char| matches!(c, '"' | '\'' | '`' | '*' | '.'))
            .trim();
        if candidate.is_empty() || candidate.eq_ignore_ascii_case("none") {
            return Ok(None);
        }
        Ok(has_driver(candidate).then(|| candidate.to_string()))
    }

    /// Runs `driver_name` with `user_input` passed as the `"input"` parameter.
    ///
    /// `conversation_history` is accepted for API compatibility but is not
    /// currently forwarded to the driver.
    ///
    /// # Errors
    /// Fails if no driver named `driver_name` is registered, or with whatever
    /// error the driver itself returns.
    pub async fn execute(
        &self,
        driver_name: &str,
        user_input: &str,
        conversation_history: &str,
        driver_callback: Option<&dyn DriverCallback>,
        driver_context: Option<&DriverContext>,
    ) -> anyhow::Result<String> {
        let mut parameters = HashMap::new();
        parameters.insert("input".to_string(), Value::String(user_input.to_string()));
        self.execute_with_parameters(
            driver_name,
            user_input,
            &parameters,
            conversation_history,
            driver_callback,
            driver_context,
        )
        .await
    }

    /// Prints the localized "executing" notice, then runs `driver_name` with
    /// the caller-supplied `parameters`.
    ///
    /// `user_input` and `conversation_history` are accepted for API
    /// compatibility but are not currently used.
    ///
    /// # Errors
    /// Fails if no driver named `driver_name` is registered, or with whatever
    /// error the driver itself returns.
    pub async fn execute_with_parameters(
        &self,
        driver_name: &str,
        user_input: &str,
        parameters: &HashMap<String, Value>,
        conversation_history: &str,
        driver_callback: Option<&dyn DriverCallback>,
        driver_context: Option<&DriverContext>,
    ) -> anyhow::Result<String> {
        // Currently unused; bound explicitly to avoid unused-variable warnings.
        let _ = (user_input, conversation_history);
        println!("{}", t!("driver.executing", driver_name));
        let driver = get_driver_by_name(driver_name)
            .ok_or_else(|| anyhow::anyhow!("Driver not found: {}", driver_name))?;
        driver
            .execute(parameters, driver_callback, driver_context)
            .await
    }

    /// Runs `driver_name` with the content of the most recent `"user"` message
    /// in `messages` as its `"input"` parameter.
    ///
    /// If `messages` contains no user message, the driver is invoked with an
    /// empty parameter map. Unlike `execute`, no "executing" notice is printed.
    ///
    /// # Errors
    /// Fails if no driver named `driver_name` is registered, or with whatever
    /// error the driver itself returns.
    pub async fn execute_with_messages(
        &self,
        driver_name: &str,
        messages: Vec<ChatMessage>,
        driver_callback: Option<&dyn DriverCallback>,
        driver_context: Option<&DriverContext>,
    ) -> anyhow::Result<String> {
        let driver = get_driver_by_name(driver_name)
            .ok_or_else(|| anyhow::anyhow!("Driver not found: {}", driver_name))?;
        let mut parameters = HashMap::new();
        // `messages` is owned, so the content can be moved instead of cloned.
        if let Some(last_user) = messages.into_iter().rev().find(|msg| msg.role == "user") {
            parameters.insert("input".to_string(), Value::String(last_user.content));
        }
        driver
            .execute(&parameters, driver_callback, driver_context)
            .await
    }

    /// Answers `user_input` directly with the LLM when no driver matched.
    ///
    /// # Errors
    /// Propagates any error returned by the LLM call.
    pub async fn fallback_chat(&self, user_input: &str) -> anyhow::Result<String> {
        let prompt = format!(
            "{}\n\nYou are a helpful assistant. No specific driver matched the user's request.\n\nUser input: {}\n\nProvide a helpful, natural response to the user.\n",
            t!("prompt.fallback"),
            user_input
        );
        let result = self.llm.generate(&prompt).await?;
        Ok(result.text)
    }

    /// Like `fallback_chat`, but includes `conversation_history` in the prompt.
    ///
    /// # Errors
    /// Propagates any error returned by the LLM call.
    pub async fn fallback_chat_with_history(
        &self,
        user_input: &str,
        conversation_history: &str,
    ) -> anyhow::Result<String> {
        let prompt = format!(
            "{}\n\nYou are a helpful assistant. No specific driver matched the user's request.\n\nPrevious conversation:\n{}\n\nUser input: {}\n\nProvide a helpful, natural response considering the conversation history.\n",
            t!("prompt.fallback"),
            conversation_history,
            user_input
        );
        let result = self.llm.generate(&prompt).await?;
        Ok(result.text)
    }

    /// Renders one line per registered driver: category icon, bold name and
    /// description. Returns the localized "no drivers" message when none exist.
    pub fn list_drivers(&self) -> String {
        let drivers = list_drivers_names();
        if drivers.is_empty() {
            return t!("driver.no_drivers_available").to_string();
        }
        let mut result = String::new();
        for name in drivers {
            if let Some(driver) = get_driver_by_name(&name) {
                let emoji = driver.category().icon();
                result.push_str(&format!(
                    " {} - **{}**: {}\n",
                    emoji,
                    name,
                    driver.description()
                ));
            }
        }
        result
    }

    /// Names of all registered drivers.
    pub fn get_driver_names(&self) -> Vec<String> {
        list_drivers_names()
    }

    /// `true` when at least one driver is registered.
    pub fn has_drivers(&self) -> bool {
        !list_drivers_names().is_empty()
    }

    /// Borrows the underlying LLM client.
    fn get_llm(&self) -> &LLMClient {
        &self.llm
    }

    /// Sends `messages` to the LLM and returns the unprocessed result.
    pub async fn chat_raw(&self, messages: Vec<ChatMessage>) -> Result<LLMResult, LangHubError> {
        self.llm.chat(messages).await
    }

    /// Sends `prompt` to the LLM's generate endpoint and returns the
    /// unprocessed result.
    pub async fn generate_raw(&self, prompt: &str) -> Result<LLMResult, LangHubError> {
        self.llm.generate(prompt).await
    }

    /// Sends `prompt` as a single user chat message and returns the reply text.
    ///
    /// # Errors
    /// Propagates any error returned by the LLM call.
    pub async fn generate(&self, prompt: &str) -> anyhow::Result<String> {
        let messages = vec![ChatMessage::user(prompt)];
        self.chat(messages).await
    }

    /// Sends `messages` to the LLM and returns the reply text.
    ///
    /// # Errors
    /// Propagates any error returned by the LLM call.
    pub async fn chat(&self, messages: Vec<ChatMessage>) -> anyhow::Result<String> {
        let result = self.llm.chat(messages).await?;
        Ok(result.text)
    }

    /// Like `generate_raw`, but also records token usage against `task_id`
    /// and returns only the reply text.
    ///
    /// # Errors
    /// Propagates any error returned by the LLM call.
    pub async fn generate_with_task(&self, prompt: &str, task_id: &str) -> anyhow::Result<String> {
        let result = self.llm.generate(prompt).await?;
        Self::record_task_usage(&result, task_id).await;
        Ok(result.text)
    }

    /// Like `chat`, but also records token usage against `task_id`.
    ///
    /// # Errors
    /// Propagates any error returned by the LLM call.
    pub async fn chat_with_task(
        &self,
        messages: Vec<ChatMessage>,
        task_id: &str,
    ) -> anyhow::Result<String> {
        let result = self.llm.chat(messages).await?;
        Self::record_task_usage(&result, task_id).await;
        Ok(result.text)
    }

    /// Adds the prompt/completion token counts reported by `result` (if any)
    /// to the global usage of task `task_id`, when that task has a state
    /// updater. Does nothing otherwise.
    async fn record_task_usage(result: &LLMResult, task_id: &str) {
        if let Some(usage) = result.extract_usage() {
            if let Some(updater) = crate::tasks::get_state_updater(task_id).await {
                updater
                    .add_token_usage_global(
                        usage.prompt_tokens as u64,
                        usage.completion_tokens as u64,
                    )
                    .await;
            }
        }
    }
}
#[cfg(test)]
mod driver_scheduler_test {
    use super::*;
    use langhub::LLMClient;
    use langhub::types::ModelProvider;

    /// Builds a scheduler backed by an OpenAI client using `api_key`.
    fn create_scheduler_with_key(api_key: String) -> DriverScheduler {
        let llm = LLMClient::new_with_key(ModelProvider::OpenAI, Some(api_key), None)
            .expect("constructing an LLMClient with an explicit key should not fail");
        DriverScheduler::new(llm)
    }

    /// Scheduler with a dummy key; suitable only for tests that never call the LLM.
    fn create_test_scheduler() -> DriverScheduler {
        create_scheduler_with_key("test-api-key".to_string())
    }

    #[test]
    fn test_list_drivers() {
        let scheduler = create_test_scheduler();
        let list = scheduler.list_drivers();
        assert!(list.contains("helloworld"));
    }

    #[test]
    fn test_get_driver_names() {
        let scheduler = create_test_scheduler();
        let names = scheduler.get_driver_names();
        assert!(names.contains(&"helloworld".to_string()));
        assert!(names.contains(&"calculator".to_string()));
        assert!(names.contains(&"file_read".to_string()));
    }

    #[test]
    fn test_has_drivers() {
        let scheduler = create_test_scheduler();
        assert!(scheduler.has_drivers());
    }

    #[test]
    fn test_get_drivers_prompt() {
        let scheduler = create_test_scheduler();
        let prompt = scheduler.get_drivers_prompt();
        assert!(prompt.contains("Available Drivers"));
        assert!(prompt.contains("helloworld"));
        assert!(prompt.contains("calculator"));
    }

    // `select_driver` performs a real LLM request whenever drivers are
    // registered, so it cannot pass with the dummy key used above. Run with
    // `cargo test -- --ignored` and OPENAI_API_KEY set.
    #[tokio::test]
    #[ignore = "requires network access and a valid OPENAI_API_KEY"]
    async fn test_select_driver_with_trigger() {
        let api_key = match std::env::var("OPENAI_API_KEY") {
            Ok(key) => key,
            Err(_) => return,
        };
        let scheduler = create_scheduler_with_key(api_key);
        let result = scheduler.select_driver("calculate 2+3").await;
        assert!(result.is_ok());
        if let Ok(Some(name)) = result {
            assert!(has_driver(&name));
        }
    }
}