use std::sync::Arc;
use async_trait::async_trait;
use serde_json::Value;
use crate::chat_history::BaseChatMessageHistory;
use crate::error::{CognisError, Result};
use crate::messages::Message;
use crate::runnables::base::Runnable;
use crate::runnables::config::RunnableConfig;
use crate::runnables::RunnableStream;
/// Describes one key of a run config's `configurable` map that can supply the session id.
#[derive(Clone, Debug)]
pub struct ConfigurableFieldSpec {
    /// Key looked up in `RunnableConfig::configurable`.
    pub id: String,
    /// Optional display name; not read by the code in this file.
    pub name: Option<String>,
    /// Optional description; not read by the code in this file.
    pub description: Option<String>,
    /// Fallback session id used when no string value is configured for `id`.
    pub default: Option<String>,
    /// Not read by the code in this file; presumably marks the field as shared (confirm).
    pub is_shared: bool,
}

impl ConfigurableFieldSpec {
    /// Creates a spec for `id` with no name, description or default, and `is_shared` unset.
    pub fn new(id: impl Into<String>) -> Self {
        let id: String = id.into();
        Self {
            id,
            is_shared: false,
            default: None,
            description: None,
            name: None,
        }
    }
}
/// Factory that maps a session id to the chat history backing that session.
type SessionHistoryFactory = Box<dyn Fn(&str) -> Arc<dyn BaseChatMessageHistory> + Sync + Send>;

/// Runnable wrapper that injects a per-session chat history into each call's input (when the
/// input is a JSON object) and records the call's messages back into that history.
pub struct RunnableWithMessageHistory {
    /// The wrapped runnable; it receives the input enriched with the prior history.
    runnable: Arc<dyn Runnable>,
    /// Called with the resolved session id to obtain that session's history.
    get_session_history: SessionHistoryFactory,
    /// Input key whose chat messages are recorded; `None` means `"input"`.
    input_messages_key: Option<String>,
    /// Output key whose chat messages are recorded; `None` means `"output"`.
    output_messages_key: Option<String>,
    /// Input key the existing history is written to; `None` means `"history"`.
    history_messages_key: Option<String>,
    /// Specs consulted, in order, to resolve the session id from the run config.
    history_factory_config: Vec<ConfigurableFieldSpec>,
}
impl RunnableWithMessageHistory {
pub fn new(
runnable: Arc<dyn Runnable>,
get_session_history: impl Fn(&str) -> Arc<dyn BaseChatMessageHistory> + Send + Sync + 'static,
) -> Self {
Self {
runnable,
get_session_history: Box::new(get_session_history),
input_messages_key: None,
output_messages_key: None,
history_messages_key: None,
history_factory_config: vec![ConfigurableFieldSpec::new("session_id")],
}
}
pub fn with_input_messages_key(mut self, key: impl Into<String>) -> Self {
self.input_messages_key = Some(key.into());
self
}
pub fn with_output_messages_key(mut self, key: impl Into<String>) -> Self {
self.output_messages_key = Some(key.into());
self
}
pub fn with_history_messages_key(mut self, key: impl Into<String>) -> Self {
self.history_messages_key = Some(key.into());
self
}
pub fn with_history_factory_config(mut self, config: Vec<ConfigurableFieldSpec>) -> Self {
self.history_factory_config = config;
self
}
fn get_session_id(&self, config: Option<&RunnableConfig>) -> Result<String> {
let config = config.ok_or_else(|| {
CognisError::Other("RunnableConfig is required with session_id in configurable".into())
})?;
for spec in &self.history_factory_config {
if let Some(val) = config.configurable.get(&spec.id) {
if let Some(s) = val.as_str() {
return Ok(s.to_string());
}
}
if let Some(ref default) = spec.default {
return Ok(default.clone());
}
}
Err(CognisError::Other(
"session_id not found in RunnableConfig configurable".into(),
))
}
}
/// Decodes the chat messages stored under `key` in `value`.
///
/// Accepts either a JSON array of messages or a single message object. A missing key, a
/// non-object `value`, or a payload that does not deserialize as messages yields an empty
/// vector. Recording history is best-effort and never fails the surrounding call.
fn messages_under_key(value: &Value, key: &str) -> Vec<Message> {
    let raw = match value.as_object().and_then(|map| map.get(key)) {
        Some(raw) => raw,
        None => return Vec::new(),
    };
    // Try the list form first. A derived struct deserializer may also accept a JSON array,
    // so probing the single-message form first could misread a list.
    if let Ok(messages) = serde_json::from_value::<Vec<Message>>(raw.clone()) {
        return messages;
    }
    serde_json::from_value::<Message>(raw.clone())
        .map(|message| vec![message])
        .unwrap_or_default()
}

#[async_trait]
impl Runnable for RunnableWithMessageHistory {
    fn name(&self) -> &str {
        "RunnableWithMessageHistory"
    }

    /// Runs the wrapped runnable with the session's history injected, then records the turn.
    ///
    /// 1. Resolves the session id from `config` and fetches that session's history.
    /// 2. If `input` is a JSON object, inserts the existing messages under the history key
    ///    (`"history"` unless configured), overwriting any value already there.
    /// 3. Invokes the wrapped runnable with the enriched input.
    /// 4. Collects the messages found under the input key (`"input"` unless configured) of
    ///    the original input, then those under the output key (`"output"` unless configured)
    ///    of the result. Both lists and single message objects are accepted. Everything is
    ///    appended to the history in one `add_messages` call.
    ///
    /// # Errors
    ///
    /// Propagates failures from:
    /// - session-id resolution;
    /// - history reads and writes;
    /// - serializing the existing messages;
    /// - the wrapped runnable.
    ///
    /// Nothing is recorded if the wrapped runnable fails.
    async fn invoke(&self, input: Value, config: Option<&RunnableConfig>) -> Result<Value> {
        let session_id = self.get_session_id(config)?;
        let history = (self.get_session_history)(&session_id);
        let existing_messages = history.messages().await?;

        // Decode the caller's messages before `input` is moved into the enriched payload.
        // This avoids cloning the whole input just to read one key afterwards.
        let input_key = self.input_messages_key.as_deref().unwrap_or("input");
        let mut new_messages = messages_under_key(&input, input_key);

        let mut enriched_input = input;
        if let Value::Object(ref mut map) = enriched_input {
            let history_key = self.history_messages_key.as_deref().unwrap_or("history");
            let history_value = serde_json::to_value(&existing_messages)?;
            map.insert(history_key.to_string(), history_value);
        }

        let output = self.runnable.invoke(enriched_input, config).await?;

        let output_key = self.output_messages_key.as_deref().unwrap_or("output");
        new_messages.extend(messages_under_key(&output, output_key));
        // One write per turn, so a failed write cannot leave only the input half recorded.
        if !new_messages.is_empty() {
            history.add_messages(new_messages).await?;
        }
        Ok(output)
    }

    /// Runs [`Runnable::invoke`] to completion and yields its result as a one-item stream.
    async fn stream(
        &self,
        input: Value,
        config: Option<&RunnableConfig>,
    ) -> Result<RunnableStream> {
        let result = self.invoke(input, config).await;
        Ok(Box::pin(futures::stream::once(async { result })))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chat_history::InMemoryChatMessageHistory;
    use serde_json::json;
    use std::collections::HashMap;
    use std::sync::Mutex;

    /// Runnable that returns its input unchanged. Tests use it to see exactly what
    /// `RunnableWithMessageHistory` forwarded to the wrapped runnable.
    struct EchoRunnable;

    #[async_trait]
    impl Runnable for EchoRunnable {
        fn name(&self) -> &str {
            "EchoRunnable"
        }
        async fn invoke(&self, input: Value, _config: Option<&RunnableConfig>) -> Result<Value> {
            Ok(input)
        }
    }

    /// Wraps `EchoRunnable` with a factory that hands out a fresh in-memory history per call.
    fn echo_with_fresh_history() -> RunnableWithMessageHistory {
        RunnableWithMessageHistory::new(Arc::new(EchoRunnable), |_| {
            Arc::new(InMemoryChatMessageHistory::new())
        })
    }

    /// Builds a run config whose `configurable` map holds only `session_id`.
    fn session_config(session_id: &str) -> RunnableConfig {
        let mut configurable = HashMap::new();
        configurable.insert(
            "session_id".to_string(),
            Value::String(session_id.to_string()),
        );
        RunnableConfig {
            configurable,
            ..RunnableConfig::default()
        }
    }

    #[tokio::test]
    async fn test_session_id_required() {
        let runnable = echo_with_fresh_history();
        let result = runnable.invoke(json!({"input": "hello"}), None).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_session_id_missing_from_configurable() {
        let runnable = echo_with_fresh_history();
        let config = RunnableConfig::default();
        let result = runnable
            .invoke(json!({"input": "hello"}), Some(&config))
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_invoke_with_session() {
        let runnable = echo_with_fresh_history();
        let config = session_config("test-session");
        let output = runnable
            .invoke(json!({"input": "hello"}), Some(&config))
            .await
            .expect("invoke with a configured session_id should succeed");
        // The wrapped runnable sees the original fields plus the (empty) history.
        assert_eq!(output["input"], json!("hello"));
        assert_eq!(output["history"], json!([]));
    }

    #[tokio::test]
    async fn test_custom_history_key() {
        let runnable = echo_with_fresh_history().with_history_messages_key("chat_history");
        let config = session_config("test-session");
        let output = runnable
            .invoke(json!({"input": "hello"}), Some(&config))
            .await
            .expect("invoke with a configured session_id should succeed");
        assert_eq!(output["chat_history"], json!([]));
        assert!(output.get("history").is_none());
    }

    #[tokio::test]
    async fn test_non_object_input_passes_through() {
        let runnable = echo_with_fresh_history();
        let config = session_config("test-session");
        let output = runnable
            .invoke(json!("hello"), Some(&config))
            .await
            .expect("invoke with a configured session_id should succeed");
        assert_eq!(output, json!("hello"));
    }

    #[tokio::test]
    async fn test_factory_receives_session_id() {
        let seen = Arc::new(Mutex::new(Vec::<String>::new()));
        let recorder = Arc::clone(&seen);
        let runnable = RunnableWithMessageHistory::new(Arc::new(EchoRunnable), move |id| {
            recorder.lock().unwrap().push(id.to_string());
            Arc::new(InMemoryChatMessageHistory::new())
        });
        let config = session_config("abc");
        runnable
            .invoke(json!({"input": "hello"}), Some(&config))
            .await
            .expect("invoke with a configured session_id should succeed");
        assert_eq!(*seen.lock().unwrap(), vec!["abc".to_string()]);
    }
}