//! rune-chain-core 0.1.0
//!
//! Core traits and types for the rune-chain LLM orchestration framework.
use std::{collections::HashMap, pin::Pin};

use async_trait::async_trait;
use futures::Stream;
use serde_json::Value;

use crate::{ChainError, GenerateResult, PromptArgs, StreamData};

/// Map key under which [`Chain::execute`] always stores the serialised result.
/// It is also the second entry of the default [`Chain::output_keys`].
const DEFAULT_RESULT_KEY: &str = "generate_result";
/// Map key for the generated text. [`Chain::execute`] falls back to it when
/// [`Chain::output_keys`] is empty. It is also the first default output key.
const DEFAULT_OUTPUT_KEY: &str = "output";

/// A composable unit of LLM work that maps [`PromptArgs`] to a [`GenerateResult`].
///
/// Chains are the core abstraction of the `rune-chain` ecosystem. A chain may wrap
/// a single LLM call, a retrieval step, a sequence of sub-chains, or any other
/// unit of work that accepts named variables and produces generated text.
///
/// Implementors must provide [`Chain::call`]. The remaining methods have defaults:
/// - [`Chain::invoke`] and [`Chain::execute`] are thin wrappers around [`Chain::call`].
/// - [`Chain::stream`] reports streaming as unsupported.
/// - [`Chain::input_keys`] and [`Chain::output_keys`] return fixed key lists.
///
/// # Example
///
/// ```rust,ignore
/// use rune_chain_core::{Chain, ChainError, GenerateResult, PromptArgs, prompt_args};
/// use async_trait::async_trait;
///
/// struct EchoChain;
///
/// #[async_trait]
/// impl Chain for EchoChain {
///     async fn call(&self, input: PromptArgs) -> Result<GenerateResult, ChainError> {
///         let text = input["input"].as_str().unwrap_or("").to_string();
///         Ok(GenerateResult::from_text(text))
///     }
/// }
///
/// # tokio_test::block_on(async {
/// let chain = EchoChain;
/// let result = chain.invoke(prompt_args! { "input" => "hello" }).await.unwrap();
/// assert_eq!(result, "hello");
/// # });
/// ```
#[async_trait]
pub trait Chain: Sync + Send {
    /// Run the chain and return the full [`GenerateResult`] including token usage.
    ///
    /// # Errors
    ///
    /// Returns a [`ChainError`] if the chain's work fails. The conditions are
    /// defined by each implementor.
    async fn call(&self, input_variables: PromptArgs) -> Result<GenerateResult, ChainError>;

    /// Run the chain and return the generated text string only.
    ///
    /// Convenience wrapper around [`Chain::call`] that discards token usage.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`Chain::call`].
    async fn invoke(&self, input_variables: PromptArgs) -> Result<String, ChainError> {
        self.call(input_variables).await.map(|r| r.generation)
    }

    /// Run the chain and return a named-output map ready to pipe into the next step.
    ///
    /// The returned map contains:
    /// - the first key from [`Chain::output_keys`] (or `"output"` if that list is
    ///   empty) → generated text
    /// - `"generate_result"` → the full serialised [`GenerateResult`], or `null` if
    ///   serialisation fails
    ///
    /// If the first output key is itself `"generate_result"`, the serialised result
    /// overwrites the generated text under that key.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`Chain::call`].
    async fn execute(
        &self,
        input_variables: PromptArgs,
    ) -> Result<HashMap<String, Value>, ChainError> {
        let result = self.call(input_variables).await?;
        let output_key = self
            .output_keys()
            .into_iter()
            .next()
            .unwrap_or_else(|| DEFAULT_OUTPUT_KEY.to_string());
        // Serialise while `result` is still whole, so `generation` can then be
        // moved into the map instead of cloned. A serialisation failure is
        // treated as best-effort and recorded as `null`.
        let serialised = serde_json::to_value(&result).unwrap_or(Value::Null);
        let mut map = HashMap::with_capacity(2);
        map.insert(output_key, Value::String(result.generation));
        // Inserted last, so this entry wins on a key collision with `output_key`.
        map.insert(DEFAULT_RESULT_KEY.to_string(), serialised);
        Ok(map)
    }

    /// Stream tokens as they are produced rather than waiting for the full completion.
    ///
    /// Returns a [`Stream`] of [`StreamData`] chunks. Chains that do not support
    /// streaming return [`ChainError::Other`] by default.
    ///
    /// > [!NOTE]
    /// > Memory layers cannot be updated automatically during a stream because the
    /// > complete output is only known when the stream ends. Callers are responsible
    /// > for persisting the accumulated output after draining the stream.
    ///
    /// # Errors
    ///
    /// The default implementation always returns [`ChainError::Other`].
    async fn stream(
        &self,
        _input_variables: PromptArgs,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamData, ChainError>> + Send>>, ChainError>
    {
        Err(ChainError::Other(
            "streaming is not implemented for this chain".into(),
        ))
    }

    /// The variable names this chain reads from its [`PromptArgs`] input.
    ///
    /// The default implementation returns an empty list.
    fn input_keys(&self) -> Vec<String> {
        vec![]
    }

    /// The keys this chain writes into the map returned by [`Chain::execute`].
    ///
    /// The default is `["output", "generate_result"]`. The default
    /// [`Chain::execute`] uses only the first key for the generated text and
    /// always writes `"generate_result"`, regardless of what this returns.
    fn output_keys(&self) -> Vec<String> {
        vec![
            DEFAULT_OUTPUT_KEY.to_string(),
            DEFAULT_RESULT_KEY.to_string(),
        ]
    }
}

/// Boxes any concrete [`Chain`] into a type-erased `Box<dyn Chain>`.
///
/// This lets callers write `chain.into()` wherever a boxed chain is expected.
impl<T> From<T> for Box<dyn Chain>
where
    T: Chain + 'static,
{
    fn from(value: T) -> Self {
        Box::new(value) as Self
    }
}