use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use async_trait::async_trait;
use futures::Stream;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::error::{Result, CognisError};
use crate::messages::{AIMessage, HumanMessage, Message};
use crate::outputs::{ChatGenerationChunk, ChatResult};
use crate::runnables::base::Runnable;
use crate::runnables::config::RunnableConfig;
use crate::tools::ToolSchema;
/// Declarative configuration for coercing a chat model's output into a
/// structured (schema-conforming) form.
///
/// Produced by [`BaseChatModel::with_structured_output`]; it records *how*
/// structured output should be obtained rather than performing the coercion
/// itself.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StructuredOutputModel {
    /// JSON schema the model output is expected to conform to.
    pub schema: Value,
    /// Extraction strategy; `with_structured_output` defaults this to
    /// "tool_calling" when the caller passes `None`.
    pub method: String,
    /// Whether the raw model response should accompany the parsed output
    /// (consumed by callers of this config; not interpreted here).
    pub include_raw: bool,
}
/// Constrains which tool(s) a model may call during generation.
///
/// Derives `PartialEq`/`Eq` (consistent with `StreamingMode` and
/// `ModelProfile`) so callers can compare choices directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolChoice {
    /// Let the model decide whether (and which) tool to call.
    Auto,
    /// Require some tool call, model picks which.
    Any,
    /// Require a call to the named tool.
    Tool(String),
    /// Forbid tool calls.
    None,
}
/// Policy describing when a chat model should stream its response.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StreamingMode {
    /// Always stream.
    Always,
    /// Never stream.
    Never,
    /// Stream except when tool calling is in play (name-derived semantics —
    /// confirm against implementors that interpret this mode).
    SkipToolCalling,
}
/// Capability/limit profile for a chat model.
///
/// Every field is optional: `None` means "unknown/unspecified". Unknown
/// fields are omitted from the serialized form (`skip_serializing_if`) and
/// tolerated when absent on deserialization (`default`), so profiles can be
/// sparse. `Default` yields an all-unknown profile.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub struct ModelProfile {
    // --- Input limits and modalities ---
    /// Maximum number of input tokens accepted, if known.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_input_tokens: Option<usize>,
    /// Whether plain-text inputs are supported.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub text_inputs: Option<bool>,
    /// Whether inline image inputs are supported.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub image_inputs: Option<bool>,
    /// Whether image-by-URL inputs are supported.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub image_url_inputs: Option<bool>,
    /// Whether PDF inputs are supported.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub pdf_inputs: Option<bool>,
    /// Whether audio inputs are supported.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub audio_inputs: Option<bool>,
    /// Whether video inputs are supported.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub video_inputs: Option<bool>,
    /// Whether images may appear inside tool messages.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub image_tool_message: Option<bool>,
    /// Whether PDFs may appear inside tool messages.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub pdf_tool_message: Option<bool>,
    // --- Output limits and modalities ---
    /// Maximum number of output tokens produced, if known.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_output_tokens: Option<usize>,
    /// Whether the model can emit reasoning output.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_output: Option<bool>,
    /// Whether plain-text outputs are supported.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub text_outputs: Option<bool>,
    /// Whether image outputs are supported.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub image_outputs: Option<bool>,
    /// Whether audio outputs are supported.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub audio_outputs: Option<bool>,
    /// Whether video outputs are supported.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub video_outputs: Option<bool>,
    // --- Tooling and structured output ---
    /// Whether the model supports tool calling at all.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_calling: Option<bool>,
    /// Whether a specific tool choice can be forced.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<bool>,
    /// Whether native structured output is supported.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub structured_output: Option<bool>,
}
/// Maps a model identifier to its capability profile.
pub type ModelProfileRegistry = HashMap<String, ModelProfile>;
/// Pinned, boxed, `Send` stream of incremental chat generation chunks
/// (each item may be an error).
pub type ChatStream = Pin<Box<dyn Stream<Item = Result<ChatGenerationChunk>> + Send>>;
/// Core interface implemented by every chat model backend.
///
/// Implementors must provide [`BaseChatModel::_generate`] and
/// [`BaseChatModel::llm_type`]; streaming and tool binding have "not
/// supported" defaults, and the remaining methods are convenience wrappers
/// built on `_generate`.
#[async_trait]
pub trait BaseChatModel: Send + Sync {
    /// Produce a complete (non-streaming) result for one conversation.
    ///
    /// `stop` optionally lists stop sequences to pass to the backend.
    async fn _generate(&self, messages: &[Message], stop: Option<&[String]>) -> Result<ChatResult>;

    /// Short identifier for the backing model family; used in error messages
    /// (e.g. by the default `bind_tools`).
    fn llm_type(&self) -> &str;

    /// Stream generation chunks for one conversation.
    ///
    /// # Errors
    /// The default implementation returns [`CognisError::NotImplemented`];
    /// backends that support streaming must override it.
    async fn _stream(&self, _messages: &[Message], _stop: Option<&[String]>) -> Result<ChatStream> {
        Err(CognisError::NotImplemented(
            "Streaming not supported for this chat model".into(),
        ))
    }

    /// Return a new model instance with the given tools bound.
    ///
    /// # Errors
    /// The default implementation reports that tool binding is unsupported
    /// for this `llm_type`.
    fn bind_tools(
        &self,
        _tools: &[ToolSchema],
        _tool_choice: Option<ToolChoice>,
    ) -> Result<Box<dyn BaseChatModel>> {
        Err(CognisError::NotImplemented(format!(
            "{} does not support tool binding",
            self.llm_type()
        )))
    }

    /// Capability profile of this model; defaults to all-unknown.
    fn profile(&self) -> ModelProfile {
        ModelProfile::default()
    }

    /// Rough token-count estimate using a ~4-bytes-per-token heuristic.
    ///
    /// NOTE(review): counts UTF-8 bytes (not characters) and truncates per
    /// message; override with a real tokenizer where accuracy matters.
    fn get_num_tokens_from_messages(&self, messages: &[Message]) -> usize {
        messages.iter().map(|m| m.content().text().len() / 4).sum()
    }

    /// Build a structured-output configuration for this model.
    ///
    /// `method` falls back to `"tool_calling"` when `None`.
    async fn with_structured_output(
        &self,
        schema: Value,
        method: Option<&str>,
        include_raw: bool,
    ) -> Result<StructuredOutputModel> {
        Ok(StructuredOutputModel {
            schema,
            method: method.unwrap_or("tool_calling").to_string(),
            include_raw,
        })
    }

    /// Run [`BaseChatModel::_generate`] sequentially over several
    /// conversations, one result per batch.
    ///
    /// # Errors
    /// Fails fast: the first erroring batch aborts the remaining ones.
    async fn generate(
        &self,
        message_batches: &[Vec<Message>],
        stop: Option<&[String]>,
    ) -> Result<Vec<ChatResult>> {
        let mut results = Vec::with_capacity(message_batches.len());
        for messages in message_batches {
            results.push(self._generate(messages, stop).await?);
        }
        Ok(results)
    }

    /// Generate once and return the first generation's [`AIMessage`].
    ///
    /// # Errors
    /// Fails if the backend returns no generations, or if the first
    /// generation's message is not an AI message.
    async fn invoke_messages(
        &self,
        messages: &[Message],
        stop: Option<&[String]>,
    ) -> Result<AIMessage> {
        let result = self._generate(messages, stop).await?;
        // Renamed from `gen`: `gen` is a reserved keyword in Rust edition
        // 2024, so avoid it as an identifier.
        let generation = result
            .generations
            .into_iter()
            .next()
            .ok_or_else(|| CognisError::Other("No generations returned".into()))?;
        match generation.message {
            Message::Ai(ai_msg) => Ok(ai_msg),
            _ => Err(CognisError::Other(
                "Expected AIMessage in ChatGeneration, got a different message type".into(),
            )),
        }
    }
}
/// Adapts a [`BaseChatModel`] to the generic [`Runnable`] interface so a chat
/// model can participate in runnable pipelines (JSON in, JSON out).
pub struct ChatModelRunnable {
    /// The wrapped chat model.
    model: Arc<dyn BaseChatModel>,
    /// Display name, derived from the model's `llm_type` at construction.
    name: String,
}
impl ChatModelRunnable {
    /// Wrap `model`, deriving the runnable's display name from its
    /// `llm_type` (formatted as `ChatModelRunnable(<llm_type>)`).
    pub fn new(model: Arc<dyn BaseChatModel>) -> Self {
        Self {
            name: format!("ChatModelRunnable({})", model.llm_type()),
            model,
        }
    }
}
#[async_trait]
impl Runnable for ChatModelRunnable {
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Parse the JSON input into messages, invoke the model once, and return
    /// the resulting AI message serialized back to JSON.
    async fn invoke(&self, input: Value, _config: Option<&RunnableConfig>) -> Result<Value> {
        let msgs = parse_chat_input(input)?;
        let reply = self.model.invoke_messages(&msgs, None).await?;
        match serde_json::to_value(&reply) {
            Ok(value) => Ok(value),
            Err(e) => Err(CognisError::Other(e.to_string())),
        }
    }

    /// Parse the JSON input into messages and forward the model's chunk
    /// stream, serializing each chunk to JSON as it arrives.
    async fn stream(
        &self,
        input: Value,
        _config: Option<&RunnableConfig>,
    ) -> Result<crate::runnables::RunnableStream> {
        use futures::StreamExt;
        let msgs = parse_chat_input(input)?;
        let chunks = self.model._stream(&msgs, None).await?;
        Ok(Box::pin(chunks.map(|item| {
            item.and_then(|chunk| {
                serde_json::to_value(&chunk).map_err(|e| CognisError::Other(e.to_string()))
            })
        })))
    }
}
/// Coerce a runnable's JSON input into a list of chat [`Message`]s.
///
/// Accepted shapes:
/// - a JSON string → a single human message;
/// - a JSON array  → deserialized directly as `Vec<Message>`;
/// - a JSON object with a `"messages"` array, or an `"input"` string.
///
/// # Errors
/// Returns [`CognisError::Other`] on deserialization failure, and
/// [`CognisError::TypeMismatch`] for unsupported shapes.
fn parse_chat_input(input: Value) -> Result<Vec<Message>> {
    match input {
        Value::String(s) => Ok(vec![Message::Human(HumanMessage::new(&s))]),
        Value::Array(_) => serde_json::from_value(input)
            .map_err(|e| CognisError::Other(format!("Failed to deserialize messages: {e}"))),
        Value::Object(ref map) => {
            if let Some(msgs) = map.get("messages") {
                serde_json::from_value(msgs.clone()).map_err(|e| {
                    CognisError::Other(format!("Failed to deserialize messages: {e}"))
                })
            } else if let Some(text) = map.get("input").and_then(|v| v.as_str()) {
                Ok(vec![Message::Human(HumanMessage::new(text))])
            } else {
                Err(CognisError::TypeMismatch {
                    expected: "String, Array, or Object with 'messages'/'input'".into(),
                    got: "Object without recognized keys".into(),
                })
            }
        }
        // Null / Bool / Number: keep the "expected" text consistent with the
        // object arm (objects ARE accepted) and report the JSON type name
        // instead of echoing the whole value into the error.
        other => Err(CognisError::TypeMismatch {
            expected: "String, Array, or Object with 'messages'/'input'".into(),
            got: match other {
                Value::Null => "null".to_string(),
                Value::Bool(_) => "boolean".to_string(),
                Value::Number(_) => "number".to_string(),
                // String/Array/Object are consumed by the arms above.
                _ => unreachable!("handled by earlier match arms"),
            },
        }),
    }
}