use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use cognis_core::error::Result;
use cognis_core::language_models::chat_model::BaseChatModel;
use cognis_core::messages::Message;
use cognis_core::tools::base::BaseTool;
/// Control-flow target that can be stored in `AgentState::jump_to`.
///
/// Variants serialize as lowercase strings (`"tools"`, `"model"`, `"end"`)
/// because of `rename_all = "lowercase"`.
///
/// NOTE(review): the variant names suggest "which step of the agent loop runs
/// next", but the code that consumes this value is not in this file — confirm
/// against the executor.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum JumpTo {
/// Serialized as `"tools"`.
Tools,
/// Serialized as `"model"`.
Model,
/// Serialized as `"end"`.
End,
}
/// Agent state handed to `AgentMiddleware` hooks and embedded in `ModelRequest`.
///
/// Every field except `messages` is optional on the wire: it defaults when
/// missing during deserialization and is omitted from serialized output when
/// unset/empty.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AgentState {
/// Messages held by the state.
pub messages: Vec<Message>,
/// Optional structured payload; omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub structured_response: Option<Value>,
/// Optional jump target; omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub jump_to: Option<JumpTo>,
/// Free-form key/value entries; omitted from serialized output when empty.
/// `apply_updates` stores every key it does not special-case in this map.
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub extra: HashMap<String, Value>,
}
impl AgentState {
    /// Creates a state holding `messages`; every other field takes its default.
    pub fn new(messages: Vec<Message>) -> Self {
        Self {
            messages,
            ..Default::default()
        }
    }

    /// Looks up `key` in the `extra` map, returning `None` when it is absent.
    pub fn get_extra(&self, key: &str) -> Option<&Value> {
        self.extra.get(key)
    }

    /// Stores `value` under `key` in the `extra` map, replacing any previous entry.
    pub fn set_extra(&mut self, key: impl Into<String>, value: Value) {
        self.extra.insert(key.into(), value);
    }

    /// Merges a map of updates into this state.
    ///
    /// Keys are handled as follows:
    ///
    /// * `"jump_to"` — the value is deserialized as a [`JumpTo`]. Anything that
    ///   does not deserialize (including JSON `null`) clears the jump target.
    /// * `"structured_response"` — JSON `null` clears the field; any other
    ///   value replaces it.
    /// * every other key — inserted into `extra`, overwriting an existing entry.
    ///
    /// NOTE(review): a `"messages"` key is not special-cased and therefore lands
    /// in `extra` rather than touching `self.messages` — confirm this is intended.
    pub fn apply_updates(&mut self, updates: HashMap<String, Value>) {
        for (key, value) in updates {
            match key.as_str() {
                "jump_to" => {
                    // Best effort: an unparseable target leaves no jump pending.
                    self.jump_to = serde_json::from_value(value).ok();
                }
                "structured_response" => {
                    // Storing `Some(Value::Null)` would make `is_some()` report a
                    // response that is not there, and the state would not survive a
                    // serde round-trip (a serialized `null` deserializes to `None`).
                    self.structured_response = if value.is_null() {
                        None
                    } else {
                        Some(value)
                    };
                }
                _ => {
                    self.extra.insert(key, value);
                }
            }
        }
    }
}
/// Bundle of inputs passed by reference to model handlers (`ModelHandler`,
/// `AsyncModelHandler`) and to `AgentMiddleware::wrap_model_call`.
///
/// NOTE(review): the field descriptions below follow the field names and
/// types; how the executor interprets them is not visible in this file.
pub struct ModelRequest {
/// Shared handle to the chat model.
pub model: Arc<dyn BaseChatModel>,
/// Messages for the call.
pub messages: Vec<Message>,
/// Optional system message; `system_prompt` exposes its text.
pub system_message: Option<Message>,
/// Optional tool-choice value (opaque JSON).
pub tool_choice: Option<Value>,
/// Tools attached to the request.
pub tools: Vec<Arc<dyn BaseTool>>,
/// Optional response-format value (opaque JSON).
pub response_format: Option<Value>,
/// Agent state attached to the request.
pub state: AgentState,
/// Additional settings keyed by name (opaque JSON values).
pub model_settings: HashMap<String, Value>,
}
impl ModelRequest {
    /// Builds a request for `model` over `messages`.
    ///
    /// All remaining fields start out empty: no system message, tool choice or
    /// response format, no tools, a default [`AgentState`] and no model settings.
    pub fn new(model: Arc<dyn BaseChatModel>, messages: Vec<Message>) -> Self {
        Self {
            model,
            messages,
            state: AgentState::default(),
            tools: Vec::new(),
            model_settings: HashMap::new(),
            system_message: None,
            tool_choice: None,
            response_format: None,
        }
    }

    /// Returns the text of the system message, or `None` when none is set.
    pub fn system_prompt(&self) -> Option<String> {
        let message = self.system_message.as_ref()?;
        Some(message.content().text())
    }

    /// Builder-style setter: returns the request with `msg` as its system message.
    pub fn with_system_message(self, msg: Message) -> Self {
        Self {
            system_message: Some(msg),
            ..self
        }
    }

    /// Builder-style setter: returns the request with its tool list replaced by `tools`.
    pub fn with_tools(self, tools: Vec<Arc<dyn BaseTool>>) -> Self {
        Self { tools, ..self }
    }

    /// Builder-style setter: returns the request with its state replaced by `state`.
    pub fn with_state(self, state: AgentState) -> Self {
        Self { state, ..self }
    }
}
/// Result of a model call: a list of messages plus an optional structured payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelResponse {
/// Messages making up the result.
pub result: Vec<Message>,
/// Optional structured payload; defaults to `None` when missing on
/// deserialization and is omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub structured_response: Option<Value>,
}
impl ModelResponse {
    /// Wraps `result` in a response that carries no structured payload.
    pub fn new(result: Vec<Message>) -> Self {
        let structured_response = None;
        Self {
            result,
            structured_response,
        }
    }

    /// Builder-style setter: returns the response with `response` as its
    /// structured payload, replacing any previous one.
    pub fn with_structured_response(self, response: Value) -> Self {
        Self {
            structured_response: Some(response),
            ..self
        }
    }
}
/// A `ModelResponse` paired with an optional map of state updates.
///
/// `normalize_model_call_result` keeps only `model_response` and discards
/// `state_update`.
///
/// NOTE(review): the map has the same shape as the argument of
/// `AgentState::apply_updates`, which is presumably where it is meant to end
/// up — confirm against the caller.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtendedModelResponse {
/// The wrapped model response.
pub model_response: ModelResponse,
/// Optional keyed updates; defaults to `None` when missing on
/// deserialization and is omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub state_update: Option<HashMap<String, Value>>,
}
/// The shapes `AgentMiddleware::wrap_model_call` may return.
///
/// Each payload type converts into this enum via `From`, and
/// `normalize_model_call_result` collapses any variant into a `ModelResponse`.
pub enum ModelCallResult {
/// A complete model response.
Response(ModelResponse),
/// A single message, boxed.
Message(Box<Message>),
/// A model response together with optional state updates.
Extended(ExtendedModelResponse),
}
impl From<ModelResponse> for ModelCallResult {
fn from(r: ModelResponse) -> Self {
ModelCallResult::Response(r)
}
}
impl From<Message> for ModelCallResult {
fn from(m: Message) -> Self {
ModelCallResult::Message(Box::new(m))
}
}
impl From<ExtendedModelResponse> for ModelCallResult {
fn from(e: ExtendedModelResponse) -> Self {
ModelCallResult::Extended(e)
}
}
/// Collapses any [`ModelCallResult`] variant into a plain [`ModelResponse`].
///
/// * `Response` is returned unchanged.
/// * `Message` becomes a response whose `result` holds just that message.
/// * `Extended` yields its inner `model_response`; the accompanying
///   `state_update` is discarded.
pub fn normalize_model_call_result(result: ModelCallResult) -> ModelResponse {
    match result {
        ModelCallResult::Response(response) => response,
        ModelCallResult::Extended(extended) => extended.model_response,
        ModelCallResult::Message(boxed) => {
            // Move the message out of its box before wrapping it.
            let message = *boxed;
            ModelResponse::new(vec![message])
        }
    }
}
/// Boxed synchronous model handler: a thread-safe (`Send + Sync`) callable that
/// takes a borrowed `ModelRequest` and returns a `ModelResponse` or an error.
pub type ModelHandler = Box<dyn Fn(&ModelRequest) -> Result<ModelResponse> + Send + Sync>;
/// Boxed asynchronous model handler: a thread-safe (`Send + Sync`) callable that
/// takes a borrowed `ModelRequest` and returns a pinned, boxed, `Send` future
/// resolving to a `ModelResponse` or an error.
///
/// The `'_` bound ties the future to the lifetime of the borrowed request, so
/// the future may hold on to that borrow while it runs.
pub type AsyncModelHandler = Box<
dyn Fn(
&ModelRequest,
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<ModelResponse>> + Send + '_>,
> + Send
+ Sync,
>;
/// Hook points around an agent run, a model call and a tool call.
///
/// Every method has a default, so an implementor overrides only what it needs.
/// The trait is deprecated; per its deprecation note it is not invoked by the
/// executor.
///
/// NOTE(review): the four `before_*` / `after_*` hooks return an optional
/// key/value map. Its shape matches `AgentState::apply_updates`, which is
/// presumably how a caller would apply it — nothing in this file does so.
#[deprecated(
    since = "0.2.0",
    note = "Use cognisagent::middleware::Middleware instead. AgentMiddleware is not invoked by the executor and will be removed in a future release."
)]
#[async_trait]
pub trait AgentMiddleware: Send + Sync {
    /// Name of the middleware; defaults to the implementing type's name as
    /// reported by `std::any::type_name`.
    fn name(&self) -> &str {
        std::any::type_name::<Self>()
    }

    /// Tools contributed by this middleware; none by default.
    fn tools(&self) -> Vec<Arc<dyn BaseTool>> {
        Vec::new()
    }

    /// Agent-level "before" hook. The default returns `Ok(None)`.
    async fn before_agent(&self, _state: &AgentState) -> Result<Option<HashMap<String, Value>>> {
        Ok(None)
    }

    /// Agent-level "after" hook. The default returns `Ok(None)`.
    async fn after_agent(&self, _state: &AgentState) -> Result<Option<HashMap<String, Value>>> {
        Ok(None)
    }

    /// Model-level "before" hook. The default returns `Ok(None)`.
    async fn before_model(&self, _state: &AgentState) -> Result<Option<HashMap<String, Value>>> {
        Ok(None)
    }

    /// Model-level "after" hook. The default returns `Ok(None)`.
    async fn after_model(&self, _state: &AgentState) -> Result<Option<HashMap<String, Value>>> {
        Ok(None)
    }

    /// Wraps a model call.
    ///
    /// The default awaits `handler(request)` and returns its response in
    /// `ModelCallResult::Response`; a handler error is returned unchanged.
    async fn wrap_model_call(
        &self,
        request: &ModelRequest,
        handler: &AsyncModelHandler,
    ) -> Result<ModelCallResult> {
        handler(request).await.map(ModelCallResult::Response)
    }

    /// Wraps a tool call.
    ///
    /// The default calls `handler(tool, input)` synchronously and returns its
    /// result as is.
    async fn wrap_tool_call(
        &self,
        tool: &dyn BaseTool,
        input: &Value,
        handler: &(dyn for<'a, 'b> Fn(&'a dyn BaseTool, &'b Value) -> Result<Value> + Send + Sync),
    ) -> Result<Value> {
        handler(tool, input)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A default state has no messages, no structured response, no jump
    /// target and no extras.
    #[test]
    fn test_agent_state_default() {
        let AgentState {
            messages,
            structured_response,
            jump_to,
            extra,
        } = AgentState::default();
        assert!(messages.is_empty());
        assert!(structured_response.is_none());
        assert!(jump_to.is_none());
        assert!(extra.is_empty());
    }

    /// `AgentState::new` keeps the messages it is given.
    #[test]
    fn test_agent_state_new() {
        let state = AgentState::new(vec![Message::human("hello")]);
        assert_eq!(state.messages.len(), 1);
    }

    /// `set_extra` / `get_extra` round-trip a value; unknown keys yield `None`.
    #[test]
    fn test_agent_state_extra() {
        let mut state = AgentState::default();
        state.set_extra("count", json!(42));

        let expected = json!(42);
        assert_eq!(state.get_extra("count"), Some(&expected));
        assert_eq!(state.get_extra("missing"), None);
    }

    /// `apply_updates` routes `jump_to` to its field and unknown keys to `extra`.
    #[test]
    fn test_agent_state_apply_updates() {
        let updates: HashMap<String, Value> = vec![
            ("jump_to".to_string(), json!("end")),
            ("my_field".to_string(), json!("value")),
        ]
        .into_iter()
        .collect();

        let mut state = AgentState::default();
        state.apply_updates(updates);

        assert_eq!(state.jump_to, Some(JumpTo::End));
        let expected = json!("value");
        assert_eq!(state.extra.get("my_field"), Some(&expected));
    }

    /// `ModelResponse::new` stores the messages and leaves the payload unset.
    #[test]
    fn test_model_response_new() {
        let resp = ModelResponse::new(vec![Message::ai("hello")]);
        assert_eq!(resp.result.len(), 1);
        assert!(resp.structured_response.is_none());
    }

    /// `with_structured_response` sets the payload.
    #[test]
    fn test_model_response_with_structured() {
        let payload = json!({"name": "test"});
        let resp = ModelResponse::new(vec![]).with_structured_response(payload);
        assert!(resp.structured_response.is_some());
    }

    /// Normalizing a `Response` variant keeps its messages.
    #[test]
    fn test_normalize_model_call_result_response() {
        let wrapped = ModelCallResult::Response(ModelResponse::new(vec![Message::ai("hi")]));
        let result = normalize_model_call_result(wrapped);
        assert_eq!(result.result.len(), 1);
    }

    /// Normalizing a `Message` variant yields a one-message response.
    #[test]
    fn test_normalize_model_call_result_message() {
        let wrapped = ModelCallResult::Message(Box::new(Message::ai("hi")));
        let result = normalize_model_call_result(wrapped);
        assert_eq!(result.result.len(), 1);
    }

    /// `JumpTo` variants serialize as lowercase JSON strings.
    #[test]
    fn test_jump_to_serialize() {
        let end = serde_json::to_string(&JumpTo::End).unwrap();
        assert_eq!(end, "\"end\"");

        let tools = serde_json::to_string(&JumpTo::Tools).unwrap();
        assert_eq!(tools, "\"tools\"");
    }

    /// A serialized state has a `messages` array and omits an unset `jump_to`.
    #[test]
    fn test_agent_state_serialization() {
        let state = AgentState::new(vec![Message::human("test")]);
        let json = serde_json::to_value(&state).unwrap();

        let messages = json.get("messages").unwrap();
        assert!(messages.is_array());
        assert!(json.get("jump_to").is_none());
    }
}