use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
use crate::{
agent::{AgentState, Runtime},
language_models::{llm::LLM, GenerateResult},
schemas::{messages::Message, StructuredOutputStrategy},
tools::Tool,
};
use serde_json::Value;
/// Everything needed for a single model invocation: the conversation
/// messages, the tools on offer, optional per-request model and
/// structured-output overrides, the shared agent state, an optional runtime
/// handle, and free-form metadata.
pub struct ModelRequest {
    /// Messages to send to the model.
    pub messages: Vec<Message>,
    /// Tools made available for this call.
    pub tools: Vec<Arc<dyn Tool>>,
    /// Per-request model override; `None` when no override has been set
    /// (presumably the caller then uses its default model — confirm at call site).
    pub model: Option<Arc<dyn LLM>>,
    /// Per-request structured-output strategy override; `None` when unset.
    pub response_format: Option<Box<dyn StructuredOutputStrategy>>,
    /// Agent state shared via `Arc`; lock it with [`ModelRequest::state`].
    pub state: Arc<Mutex<AgentState>>,
    /// Optional runtime handle, attached with [`ModelRequest::with_runtime`].
    pub runtime: Option<Arc<Runtime>>,
    /// Arbitrary key/value annotations attached to the request.
    pub metadata: HashMap<String, Value>,
}

impl ModelRequest {
    /// Creates a request with the given messages, tools and shared state.
    ///
    /// No model / response-format override and no runtime are set, and the
    /// metadata map starts empty.
    pub fn new(
        messages: Vec<Message>,
        tools: Vec<Arc<dyn Tool>>,
        state: Arc<Mutex<AgentState>>,
    ) -> Self {
        Self {
            messages,
            tools,
            model: None,
            response_format: None,
            state,
            runtime: None,
            metadata: HashMap::new(),
        }
    }

    /// Attaches a runtime handle, replacing any previously attached one.
    pub fn with_runtime(mut self, runtime: Arc<Runtime>) -> Self {
        self.runtime = Some(runtime);
        self
    }

    /// Replaces the message list.
    pub fn override_messages(mut self, messages: Vec<Message>) -> Self {
        self.messages = messages;
        self
    }

    /// Replaces the tool list.
    pub fn override_tools(mut self, tools: Vec<Arc<dyn Tool>>) -> Self {
        self.tools = tools;
        self
    }

    /// Sets a per-request model override.
    pub fn override_model(mut self, model: Arc<dyn LLM>) -> Self {
        self.model = Some(model);
        self
    }

    /// Sets a per-request structured-output strategy override.
    pub fn override_response_format(mut self, format: Box<dyn StructuredOutputStrategy>) -> Self {
        self.response_format = Some(format);
        self
    }

    /// Stores `value` under `key`, overwriting any existing entry for `key`.
    pub fn with_metadata(mut self, key: String, value: Value) -> Self {
        self.metadata.insert(key, value);
        self
    }

    /// Returns the attached runtime handle, if any.
    pub fn runtime(&self) -> Option<&Arc<Runtime>> {
        self.runtime.as_ref()
    }

    /// Locks the shared agent state, waiting asynchronously if it is held.
    ///
    /// The lock is held until the returned guard is dropped, so keep the
    /// guard's scope short.
    pub async fn state(&self) -> tokio::sync::MutexGuard<'_, AgentState> {
        self.state.lock().await
    }

    /// Derives a new request with different messages and tools.
    ///
    /// The following are carried over from `self`:
    /// - the model override (cheap `Arc` clone),
    /// - the shared state (same `Arc`, not a copy),
    /// - the runtime handle,
    /// - a clone of the metadata.
    ///
    /// The structured-output override is **not** carried over; re-apply it
    /// with [`ModelRequest::override_response_format`] if needed.
    pub fn with_messages_and_tools(
        &self,
        messages: Vec<Message>,
        tools: Vec<Arc<dyn Tool>>,
    ) -> Self {
        Self {
            messages,
            tools,
            // Previously this always evaluated to `None`, silently dropping
            // the model override. Cloning an `Arc` is just a refcount bump.
            model: self.model.clone(),
            // A `Box<dyn StructuredOutputStrategy>` cannot be cloned without
            // clone support on the trait (none is used here), so it is reset.
            response_format: None,
            state: Arc::clone(&self.state),
            runtime: self.runtime.clone(),
            metadata: self.metadata.clone(),
        }
    }
}
/// The outcome of a model invocation: the raw [`GenerateResult`] together
/// with free-form metadata.
pub struct ModelResponse {
    /// What the model produced.
    pub result: GenerateResult,
    /// Arbitrary key/value annotations attached to the response.
    pub metadata: HashMap<String, Value>,
}

impl ModelResponse {
    /// Wraps `result` in a response whose metadata map starts out empty.
    pub fn new(result: GenerateResult) -> Self {
        let metadata = HashMap::new();
        ModelResponse { result, metadata }
    }

    /// Builder-style setter: records `value` under `key` (overwriting any
    /// earlier entry for that key) and hands the response back.
    pub fn with_metadata(self, key: String, value: Value) -> Self {
        let mut response = self;
        response.metadata.insert(key, value);
        response
    }
}

/// Lets a bare [`GenerateResult`] be converted straight into a response
/// with empty metadata.
impl From<GenerateResult> for ModelResponse {
    fn from(result: GenerateResult) -> Self {
        ModelResponse::new(result)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::{EmptyContext, InMemoryStore};

    /// Builds a runtime backed by an empty context and an in-memory store.
    fn test_runtime() -> Arc<Runtime> {
        let context = Arc::new(EmptyContext);
        let store = Arc::new(InMemoryStore::new());
        Arc::new(Runtime::new(context, store))
    }

    /// Fresh, shareable agent state for a test.
    fn test_state() -> Arc<Mutex<AgentState>> {
        Arc::new(Mutex::new(AgentState::new()))
    }

    #[test]
    fn test_model_request_creation() {
        let messages = vec![Message::new_human_message("Hello")];
        let request = ModelRequest::new(messages, vec![], test_state());
        assert_eq!(request.messages.len(), 1);
        assert!(request.tools.is_empty());
        assert!(request.model.is_none());
        assert!(request.runtime().is_none());
    }

    #[test]
    fn test_model_request_override() {
        let messages = vec![Message::new_human_message("Hello")];
        let request = ModelRequest::new(messages, vec![], test_state());
        let new_messages = vec![
            Message::new_human_message("Hello"),
            Message::new_human_message("World"),
        ];
        let request = request.override_messages(new_messages);
        assert_eq!(request.messages.len(), 2);
    }

    #[tokio::test]
    async fn test_model_request_with_runtime() {
        let messages = vec![Message::new_human_message("Hello")];
        let request = ModelRequest::new(messages, vec![], test_state()).with_runtime(test_runtime());
        assert!(request.runtime().is_some());
    }

    #[test]
    fn test_with_metadata_overwrites_existing_key() {
        let request = ModelRequest::new(vec![], vec![], test_state())
            .with_metadata("k".to_string(), Value::from(1))
            .with_metadata("k".to_string(), Value::from(2));
        assert_eq!(request.metadata.len(), 1);
        assert_eq!(request.metadata.get("k"), Some(&Value::from(2)));
    }

    #[tokio::test]
    async fn test_state_guard_holds_shared_lock() {
        let state = test_state();
        let request = ModelRequest::new(vec![], vec![], Arc::clone(&state));
        let guard = request.state().await;
        // While the guard is alive the shared mutex is held.
        assert!(state.try_lock().is_err());
        drop(guard);
        assert!(state.try_lock().is_ok());
    }

    #[test]
    fn test_with_messages_and_tools_preserves_shared_fields() {
        let state = test_state();
        let runtime = test_runtime();
        let request = ModelRequest::new(
            vec![Message::new_human_message("Hello")],
            vec![],
            Arc::clone(&state),
        )
        .with_runtime(Arc::clone(&runtime))
        .with_metadata("trace_id".to_string(), Value::String("abc".to_string()));

        let derived = request.with_messages_and_tools(
            vec![
                Message::new_human_message("Hello"),
                Message::new_human_message("again"),
            ],
            vec![],
        );

        assert_eq!(derived.messages.len(), 2);
        assert!(derived.tools.is_empty());
        // State is shared with the source request, not copied.
        assert!(Arc::ptr_eq(&derived.state, &state));
        assert!(Arc::ptr_eq(
            derived.runtime().expect("runtime should be carried over"),
            &runtime
        ));
        assert_eq!(
            derived.metadata.get("trace_id"),
            Some(&Value::String("abc".to_string()))
        );
        assert!(derived.response_format.is_none());
        // The source request is left untouched.
        assert_eq!(request.messages.len(), 1);
    }
}