use std::{collections::HashMap, pin::Pin};
use async_trait::async_trait;
use futures::Stream;
use serde_json::Value;
use crate::{ChainError, GenerateResult, PromptArgs, StreamData};
/// Fallback key under which `Chain::execute` stores the generated text when
/// `output_keys()` returns an empty list; also the first entry of the default
/// `Chain::output_keys`.
const DEFAULT_OUTPUT_KEY: &str = "output";
/// Key under which `Chain::execute` stores the serialized `GenerateResult`;
/// also the second entry of the default `Chain::output_keys`.
const DEFAULT_RESULT_KEY: &str = "generate_result";
/// A runnable unit that turns a set of prompt arguments into a [`GenerateResult`].
///
/// Implementors only have to provide [`Chain::call`]; every other method has a
/// default implementation built on top of it.
#[async_trait]
pub trait Chain: Sync + Send {
    /// Runs the chain and returns the full generation result.
    ///
    /// # Errors
    /// Returns a [`ChainError`] if the chain fails to produce a result.
    async fn call(&self, input_variables: PromptArgs) -> Result<GenerateResult, ChainError>;

    /// Runs the chain and returns only the generated text.
    ///
    /// # Errors
    /// Propagates any error returned by [`Chain::call`].
    async fn invoke(&self, input_variables: PromptArgs) -> Result<String, ChainError> {
        self.call(input_variables).await.map(|r| r.generation)
    }

    /// Runs the chain and returns its output as a JSON-valued map.
    ///
    /// The returned map holds two entries:
    /// - the generated text as a [`Value::String`], keyed by the first entry of
    ///   [`Chain::output_keys`] (or `"output"` if that list is empty);
    /// - the whole [`GenerateResult`] serialized with `serde_json`, keyed by
    ///   `"generate_result"`. If serialization fails, this value is
    ///   [`Value::Null`] rather than an error.
    ///
    /// If the first output key is itself `"generate_result"`, the serialized
    /// result (inserted second) overwrites the generated text.
    ///
    /// # Errors
    /// Propagates any error returned by [`Chain::call`].
    async fn execute(
        &self,
        input_variables: PromptArgs,
    ) -> Result<HashMap<String, Value>, ChainError> {
        let result = self.call(input_variables).await?;
        let output_key = self
            .output_keys()
            .into_iter()
            .next()
            .unwrap_or_else(|| DEFAULT_OUTPUT_KEY.to_string());
        // Serialize while `result` is still whole, so `generation` can then be
        // moved into the map instead of cloned.
        let serialized = serde_json::to_value(&result).unwrap_or(Value::Null);
        let mut map = HashMap::with_capacity(2);
        // Keep this insertion order: when the two keys collide, the serialized
        // result must win, as documented above.
        map.insert(output_key, Value::String(result.generation));
        map.insert(DEFAULT_RESULT_KEY.to_string(), serialized);
        Ok(map)
    }

    /// Runs the chain as a stream of incremental [`StreamData`] items.
    ///
    /// # Errors
    /// The default implementation always returns [`ChainError::Other`];
    /// chains that support streaming must override this method.
    async fn stream(
        &self,
        _input_variables: PromptArgs,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamData, ChainError>> + Send>>, ChainError>
    {
        Err(ChainError::Other(
            "streaming is not implemented for this chain".into(),
        ))
    }

    /// Names of the prompt arguments this chain expects. Defaults to none.
    fn input_keys(&self) -> Vec<String> {
        vec![]
    }

    /// Names of the keys this chain produces.
    ///
    /// [`Chain::execute`] uses the first entry as the key for the generated
    /// text. Defaults to `["output", "generate_result"]`.
    fn output_keys(&self) -> Vec<String> {
        vec![
            DEFAULT_OUTPUT_KEY.to_string(),
            DEFAULT_RESULT_KEY.to_string(),
        ]
    }
}
/// Boxes any concrete [`Chain`] into a `Box<dyn Chain>` trait object, so that
/// `chain.into()` works wherever a boxed chain is expected.
///
/// `C` must be `'static` because `Box<dyn Chain>` is shorthand for
/// `Box<dyn Chain + 'static>`.
impl<C> From<C> for Box<dyn Chain>
where
    C: Chain + 'static,
{
    fn from(chain: C) -> Self {
        Box::new(chain)
    }
}