use langchainrust::{Runnable, RunnableConfig};
use async_trait::async_trait;
use futures_util::StreamExt;
/// Test fixture that echoes its input back with a `"processed: "` prefix.
struct StringProcessor;

#[async_trait]
impl Runnable<String, String> for StringProcessor {
    /// Prefixing a string cannot fail.
    type Error = std::convert::Infallible;

    /// Returns `"processed: "` followed by `input`; the config is ignored.
    async fn invoke(
        &self,
        input: String,
        _config: Option<RunnableConfig>,
    ) -> Result<String, Self::Error> {
        let mut output = String::from("processed: ");
        output.push_str(&input);
        Ok(output)
    }
}
/// Test fixture that multiplies its integer input by two.
struct NumberDoubler;

#[async_trait]
impl Runnable<i32, i32> for NumberDoubler {
    /// Doubling never produces an error value.
    type Error = std::convert::Infallible;

    /// Returns `input` doubled; the config is ignored.
    ///
    /// Uses the plain `*` operator, so overflow follows the crate's
    /// overflow-check setting (by default: panic in debug, wrap in release).
    async fn invoke(
        &self,
        input: i32,
        _config: Option<RunnableConfig>,
    ) -> Result<i32, Self::Error> {
        let doubled = 2 * input;
        Ok(doubled)
    }
}
/// Test fixture whose `invoke` always fails.
struct FailingRunnable;

/// Error returned by [`FailingRunnable`]; wraps a human-readable message.
#[derive(Debug)]
struct FailingError(String);

impl std::fmt::Display for FailingError {
    /// Writes the wrapped message verbatim (formatter width/fill flags are not applied).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

impl std::error::Error for FailingError {}
#[async_trait]
impl Runnable<String, String> for FailingRunnable {
    /// Failures are reported as [`FailingError`].
    type Error = FailingError;

    /// Ignores both input and config and always returns `FailingError("invoke failed")`.
    async fn invoke(
        &self,
        _input: String,
        _config: Option<RunnableConfig>,
    ) -> Result<String, Self::Error> {
        let message = String::from("invoke failed");
        Err(FailingError(message))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default `stream` implementation should yield exactly one item: the `invoke` result.
    #[tokio::test]
    async fn test_default_stream_returns_single_element() {
        let runnable = StringProcessor;
        let mut stream = runnable.stream("test_input".to_string(), None).await.unwrap();

        // Message means: "the stream should have a first element".
        let first = stream.next().await.expect("流应有第一个元素");
        assert_eq!(first.unwrap(), "processed: test_input");

        // Message means: "the stream should have only one element".
        assert!(stream.next().await.is_none(), "流应只有一个元素");
    }

    /// The single streamed item must equal what `invoke` returns for the same input.
    #[tokio::test]
    async fn test_invoke_equals_stream_result() {
        let runnable = StringProcessor;
        let input = "hello world".to_string();
        let invoke_result = runnable.invoke(input.clone(), None).await.unwrap();
        // Last use of `input`: move it instead of cloning.
        let mut stream = runnable.stream(input, None).await.unwrap();
        let stream_result = stream.next().await.unwrap().unwrap();
        assert_eq!(invoke_result, stream_result);
    }

    /// The default `stream` works for non-string input/output types too.
    #[tokio::test]
    async fn test_stream_works_for_different_types() {
        let runnable = NumberDoubler;
        let invoke_result = runnable.invoke(5, None).await.unwrap();
        let mut stream = runnable.stream(5, None).await.unwrap();
        let stream_result = stream.next().await.unwrap().unwrap();
        assert_eq!(invoke_result, 10);
        assert_eq!(stream_result, 10);
    }

    /// Supplying a config to `stream` must not change the result or cause an error.
    ///
    /// NOTE(review): `StringProcessor` ignores its config, so this test cannot
    /// prove the config actually reaches `invoke`. A fixture that echoes a
    /// config field (e.g. the run name) is needed for that — TODO.
    #[tokio::test]
    async fn test_stream_passes_config_to_invoke() {
        let runnable = StringProcessor;
        let config = RunnableConfig::new()
            .with_tag("test_tag")
            .with_run_name("test_run");
        let mut stream = runnable.stream("input".to_string(), Some(config)).await.unwrap();
        let result = stream.next().await.unwrap().unwrap();
        assert_eq!(result, "processed: input");
    }

    /// An `invoke` failure must surface from `stream` as an error.
    #[tokio::test]
    async fn test_stream_propagates_invoke_error() {
        let runnable = FailingRunnable;
        // Pin the specific error, not just "some error", so a different failure is caught.
        let invoke_err = runnable.invoke("test".to_string(), None).await.unwrap_err();
        assert_eq!(invoke_err.to_string(), "invoke failed");

        let stream_result = runnable.stream("test".to_string(), None).await;
        // Message means: "stream should return invoke's error".
        assert!(stream_result.is_err(), "stream 应返回 invoke 的错误");
    }

    /// `batch` must produce one output per input, each matching the streamed result.
    #[tokio::test]
    async fn test_batch_consistency_with_stream() {
        let runnable = StringProcessor;
        let inputs = vec!["a".to_string(), "b".to_string(), "c".to_string()];
        let batch_results = runnable.batch(inputs.clone(), None).await.unwrap();

        // Without this, extra outputs would go unnoticed and missing ones would
        // panic with an opaque index-out-of-bounds error.
        assert_eq!(batch_results.len(), inputs.len());

        for (input, batch_result) in inputs.iter().zip(batch_results.iter()) {
            let mut stream = runnable.stream(input.clone(), None).await.unwrap();
            let stream_result = stream.next().await.unwrap().unwrap();
            assert_eq!(*batch_result, stream_result);
        }
    }

    /// Collecting the default stream yields a single successful item.
    #[tokio::test]
    async fn test_stream_can_be_collected() {
        let runnable = StringProcessor;
        let stream = runnable.stream("collected".to_string(), None).await.unwrap();
        let results: Vec<_> = stream.collect().await;
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].as_ref().unwrap(), "processed: collected");
    }

    /// Repeated `stream` calls on one instance are independent of each other.
    #[tokio::test]
    async fn test_multiple_stream_calls_on_same_instance() {
        let runnable = StringProcessor;
        for i in 0..5 {
            let input = format!("call_{}", i);
            // `input` is not needed afterwards, so move it into `stream`.
            let mut stream = runnable.stream(input, None).await.unwrap();
            let result = stream.next().await.unwrap().unwrap();
            assert_eq!(result, format!("processed: call_{}", i));
        }
    }
}