use async_trait::async_trait;
use futures_util::Stream;
use std::pin::Pin;
use super::RunnableConfig;
/// An asynchronous unit of work that turns an `Input` into an `Output`.
///
/// Implementors only have to provide [`Runnable::invoke`]; [`Runnable::batch`]
/// and [`Runnable::stream`] come with default implementations that are built
/// purely on top of `invoke`.
#[async_trait]
pub trait Runnable<Input: Send + Sync + 'static, Output: Send + Sync + 'static>: Send + Sync {
    /// Error type produced by every method of this runnable.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Runs the computation once on `input`.
    ///
    /// The default `batch` and `stream` implementations forward `config` to
    /// this method unchanged; its interpretation is up to the implementor.
    async fn invoke(
        &self,
        input: Input,
        config: Option<RunnableConfig>,
    ) -> Result<Output, Self::Error>;

    /// Runs [`Runnable::invoke`] on every element of `inputs` and returns the
    /// outputs in the same order as the inputs.
    ///
    /// The default implementation is sequential: each call is awaited before
    /// the next one starts, and every call receives its own clone of `config`.
    /// An empty `inputs` yields an empty `Vec`.
    ///
    /// # Errors
    /// Returns the first error produced by `invoke`; the remaining inputs are
    /// not processed.
    async fn batch(
        &self,
        inputs: Vec<Input>,
        config: Option<RunnableConfig>,
    ) -> Result<Vec<Output>, Self::Error> {
        let mut outputs = Vec::with_capacity(inputs.len());
        for item in inputs {
            // `invoke` takes the config by value, so each call needs its own copy.
            let cfg = config.clone();
            outputs.push(self.invoke(item, cfg).await?);
        }
        Ok(outputs)
    }

    /// Returns the output for `input` as a stream.
    ///
    /// The default implementation is not incremental. It awaits
    /// [`Runnable::invoke`] to completion and then yields that single result as
    /// a one-item stream. Implementors that can emit partial results may
    /// override it.
    ///
    /// # Errors
    /// If `invoke` fails, that error is returned directly and no stream is
    /// created.
    async fn stream(
        &self,
        input: Input,
        config: Option<RunnableConfig>,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<Output, Self::Error>> + Send>>, Self::Error> {
        let output = self.invoke(input, config).await?;
        Ok(Box::pin(futures_util::stream::once(async move {
            Ok::<Output, Self::Error>(output)
        })))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::StreamExt;

    /// Minimal runnable that prefixes its input. It implements only `invoke`,
    /// so the tests below exercise the trait's default `batch` and `stream`.
    struct TestRunnable;

    #[async_trait]
    impl Runnable<String, String> for TestRunnable {
        type Error = std::convert::Infallible;

        async fn invoke(
            &self,
            input: String,
            _config: Option<RunnableConfig>,
        ) -> Result<String, Self::Error> {
            Ok(format!("processed: {}", input))
        }
    }

    #[tokio::test]
    async fn test_default_stream_returns_single_element() {
        let runnable = TestRunnable;
        let mut stream = runnable.stream("test".to_string(), None).await.unwrap();

        let first = stream
            .next()
            .await
            .expect("default stream must yield one item");
        assert_eq!(first.unwrap(), "processed: test");

        // The default implementation wraps exactly one `invoke` result.
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_invoke_matches_stream_result() {
        let runnable = TestRunnable;
        let invoke_result = runnable.invoke("hello".to_string(), None).await.unwrap();

        let mut stream = runnable.stream("hello".to_string(), None).await.unwrap();
        let stream_result = stream
            .next()
            .await
            .expect("default stream must yield one item")
            .unwrap();

        assert_eq!(invoke_result, stream_result);
    }

    #[tokio::test]
    async fn test_default_batch_preserves_input_order() {
        let runnable = TestRunnable;
        let inputs = vec!["a".to_string(), "b".to_string(), "c".to_string()];

        let outputs = runnable.batch(inputs, None).await.unwrap();

        assert_eq!(outputs, vec!["processed: a", "processed: b", "processed: c"]);
    }

    #[tokio::test]
    async fn test_default_batch_on_empty_input_is_empty() {
        let runnable = TestRunnable;

        let outputs = runnable.batch(Vec::new(), None).await.unwrap();

        assert!(outputs.is_empty());
    }
}