lash-core 0.1.0-alpha.37

Sans-IO turn machine and runtime kernel for the lash agent runtime.
Documentation
use crate::llm::transport::LlmTransportError;
use crate::llm::types::{
    LlmAttachment, LlmContentBlock, LlmEventSender, LlmJsonSchema, LlmMessage, LlmOutputSpec,
    LlmRequest, LlmResponse, LlmRole, LlmStreamEvent, LlmToolChoice,
};
use crate::provider::ProviderHandle;
use lash_trace::{TraceContext, TraceError, TraceEvent, TraceSink};
use std::sync::Arc;

/// Speaker role attached to a [`DirectMessage`].
///
/// Serialized in `snake_case`. When a request is lowered by
/// `build_llm_request`, each variant is mapped onto the `LlmRole` variant of
/// the same name.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DirectRole {
    /// Lowered to `LlmRole::System`.
    System,
    /// Lowered to `LlmRole::User`.
    User,
    /// Lowered to `LlmRole::Assistant`.
    Assistant,
}

/// One piece of content inside a [`DirectMessage`].
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum DirectPart {
    /// Literal text. `build_llm_request` skips parts whose string is empty.
    Text(String),
    /// Image reference; the value is forwarded verbatim as the
    /// `attachment_idx` of an image content block.
    ///
    /// NOTE(review): presumably an index into `DirectRequest::attachments`;
    /// no bounds check is performed in this file — confirm against the
    /// transport layer.
    Image(usize),
}

/// A single message of a [`DirectRequest`] conversation.
///
/// A message whose parts produce no content blocks (for example, only empty
/// text parts) is left out of the request built by `build_llm_request`.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct DirectMessage {
    /// Who is speaking.
    pub role: DirectRole,
    /// Ordered content of the message.
    pub parts: Vec<DirectPart>,
}

/// Named JSON schema used by [`DirectOutputSpec::JsonSchema`].
///
/// The three fields are copied one-for-one into an `LlmJsonSchema` when the
/// request is lowered by `build_llm_request`.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct DirectJsonSchema {
    /// Name given to the schema.
    pub name: String,
    /// The schema document itself, as an arbitrary JSON value.
    pub schema: serde_json::Value,
    /// Strictness flag, forwarded unchanged.
    /// NOTE(review): exact enforcement semantics are provider-defined — confirm.
    pub strict: bool,
}

/// Requested shape of the model output for a [`DirectRequest`].
#[derive(Clone, Debug, PartialEq, Eq, Default, serde::Serialize, serde::Deserialize)]
pub enum DirectOutputSpec {
    /// Default variant. Lowered to "no output spec" (`None`) by
    /// `build_llm_request`.
    #[default]
    Text,
    /// Lowered to `LlmOutputSpec::JsonObject`.
    JsonObject,
    /// Lowered to `LlmOutputSpec::JsonSchema` carrying the given schema.
    JsonSchema(DirectJsonSchema),
}

/// A tool-free completion request that [`DirectLlmClient::complete`] sends to
/// a provider.
///
/// Every field except `model` falls back to its `Default` value when absent
/// during deserialization.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct DirectRequest {
    /// Model identifier.
    pub model: String,
    /// Optional model variant. When present, `DirectLlmClient::complete`
    /// validates it against the provider before sending. Omitted from the
    /// serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub model_variant: Option<String>,
    /// Conversation messages, in order.
    #[serde(default)]
    pub messages: Vec<DirectMessage>,
    /// Attachments, passed through unchanged to the transport request.
    #[serde(default)]
    pub attachments: Vec<LlmAttachment>,
    /// Requested output shape; defaults to [`DirectOutputSpec::Text`].
    #[serde(default)]
    pub output: DirectOutputSpec,
    /// Generation options, passed through unchanged.
    #[serde(default)]
    pub generation: crate::GenerationOptions,
    /// Optional stream-event sender. Never serialized or deserialized. When
    /// `None`, `build_llm_request` substitutes a no-op sender if the provider
    /// reports that it requires streaming.
    #[serde(default, skip)]
    pub stream_events: Option<LlmEventSender>,
    /// Optional session identifier, passed through unchanged. Omitted from
    /// the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub session_id: Option<String>,
    /// Optional causal reference. Not forwarded by `build_llm_request`.
    /// NOTE(review): presumably consumed by another runtime layer — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub caused_by: Option<crate::CausalRef>,
    /// Optional replay descriptor. Not forwarded by `build_llm_request`.
    /// NOTE(review): presumably consumed by another runtime layer — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub replay: Option<crate::RuntimeReplay>,
}

impl DirectRequest {
    /// Builds a plain-text request: a single user message containing
    /// `prompt`, text output, default generation options, and every optional
    /// field left unset.
    pub fn text(model: impl Into<String>, prompt: impl Into<String>) -> Self {
        let model = model.into();
        let user_turn = DirectMessage {
            role: DirectRole::User,
            parts: vec![DirectPart::Text(prompt.into())],
        };

        Self {
            model,
            model_variant: None,
            messages: vec![user_turn],
            attachments: vec![],
            output: DirectOutputSpec::Text,
            generation: crate::GenerationOptions::default(),
            stream_events: None,
            session_id: None,
            caused_by: None,
            replay: None,
        }
    }

    /// Same as [`DirectRequest::text`], but asks for a JSON object as output.
    pub fn json(model: impl Into<String>, prompt: impl Into<String>) -> Self {
        let mut request = Self::text(model, prompt);
        request.output = DirectOutputSpec::JsonObject;
        request
    }

    /// Same as [`DirectRequest::text`], but asks for output conforming to
    /// `schema`.
    pub fn json_schema(
        model: impl Into<String>,
        prompt: impl Into<String>,
        schema: DirectJsonSchema,
    ) -> Self {
        let mut request = Self::text(model, prompt);
        request.output = DirectOutputSpec::JsonSchema(schema);
        request
    }

    /// Returns the request with its replay descriptor set to `key`,
    /// replacing any previous value.
    pub fn with_replay_key(self, key: impl Into<String>) -> Self {
        Self {
            replay: Some(crate::RuntimeReplay { key: key.into() }),
            ..self
        }
    }

    /// Returns the request with its causal reference set to `caused_by`,
    /// replacing any previous value.
    pub fn with_caused_by(self, caused_by: crate::CausalRef) -> Self {
        Self {
            caused_by: Some(caused_by),
            ..self
        }
    }
}

/// Errors returned by [`DirectLlmClient::complete`].
#[derive(Debug, thiserror::Error, Clone)]
pub enum DirectLlmError {
    /// The request was rejected before being sent; in this file that happens
    /// when the provider refuses the requested model variant. The payload is
    /// the provider's explanation.
    #[error("invalid request: {0}")]
    InvalidRequest(String),
    /// The provider call itself failed. `#[from]` supplies the
    /// `From<LlmTransportError>` conversion.
    #[error("transport error: {0}")]
    Transport(#[from] LlmTransportError),
}

/// Client that sends [`DirectRequest`]s to a single provider and optionally
/// reports each call to a trace sink.
pub struct DirectLlmClient {
    /// Provider that executes the completion calls.
    provider: ProviderHandle,
    /// Destination for trace events; tracing is skipped entirely when `None`.
    trace_sink: Option<Arc<dyn TraceSink>>,
    /// Context passed along with every emitted trace event.
    trace_context: TraceContext,
}

impl DirectLlmClient {
    pub fn new(provider: ProviderHandle) -> Self {
        Self {
            provider,
            trace_sink: None,
            trace_context: TraceContext::default(),
        }
    }

    pub fn with_trace_sink(mut self, sink: Option<Arc<dyn TraceSink>>) -> Self {
        self.trace_sink = sink;
        self
    }

    pub fn with_trace_context(mut self, context: TraceContext) -> Self {
        self.trace_context = context;
        self
    }

    pub fn provider(&self) -> &ProviderHandle {
        &self.provider
    }

    pub fn provider_mut(&mut self) -> &mut ProviderHandle {
        &mut self.provider
    }

    pub async fn complete(
        &mut self,
        request: DirectRequest,
    ) -> Result<LlmResponse, DirectLlmError> {
        if let Some(variant) = request.model_variant.as_deref() {
            self.provider
                .validate_variant(&request.model, variant)
                .map_err(DirectLlmError::InvalidRequest)?;
        }

        let model = request.model.clone();
        let llm_request = build_llm_request(&self.provider, request, model);
        let llm_call_id = if self.trace_sink.is_some() {
            let id = uuid::Uuid::new_v4().to_string();
            crate::trace::emit_trace(
                &self.trace_sink,
                &self.trace_context,
                TraceContext::default().for_llm_call(id.clone()),
                TraceEvent::LlmCallStarted {
                    request: crate::trace::trace_llm_request(&llm_request),
                },
            );
            Some(id)
        } else {
            None
        };
        match self.provider.complete(llm_request).await {
            Ok(response) => {
                if let Some(llm_call_id) = llm_call_id {
                    crate::trace::emit_trace(
                        &self.trace_sink,
                        &self.trace_context,
                        TraceContext::default().for_llm_call(llm_call_id),
                        TraceEvent::LlmCallCompleted {
                            response: crate::trace::trace_llm_response(
                                response.full_text.clone(),
                                0,
                                Some(response.terminal_reason),
                                crate::trace::trace_output_parts(&response.parts),
                            ),
                            usage: Some(crate::trace::trace_usage_from_llm(&response.usage)),
                            provider_usage: response.provider_usage.clone(),
                            stream_summary: None,
                        },
                    );
                }
                Ok(response)
            }
            Err(error) => {
                if let Some(llm_call_id) = llm_call_id {
                    crate::trace::emit_trace(
                        &self.trace_sink,
                        &self.trace_context,
                        TraceContext::default().for_llm_call(llm_call_id),
                        TraceEvent::LlmCallFailed {
                            error: TraceError {
                                message: error.message.clone(),
                                retryable: error.retryable,
                                terminal_reason: Some(error.terminal_reason.code().to_string()),
                                code: error.code.clone(),
                                raw: error.raw.clone(),
                            },
                            stream_summary: None,
                        },
                    );
                }
                Err(DirectLlmError::from(error))
            }
        }
    }
}

pub(crate) fn build_llm_request(
    provider: &ProviderHandle,
    request: DirectRequest,
    model: String,
) -> LlmRequest {
    let stream_events = transport_stream_events_for_direct(provider, request.stream_events);
    let DirectRequest {
        model: _,
        model_variant,
        messages,
        attachments,
        output,
        generation,
        stream_events: _,
        session_id,
        caused_by: _,
        replay: _,
    } = request;

    let output_spec = match output {
        DirectOutputSpec::Text => None,
        DirectOutputSpec::JsonObject => Some(LlmOutputSpec::JsonObject),
        DirectOutputSpec::JsonSchema(schema) => Some(LlmOutputSpec::JsonSchema(LlmJsonSchema {
            name: schema.name,
            schema: schema.schema,
            strict: schema.strict,
        })),
    };

    let mut llm_messages = Vec::new();
    for message in messages {
        let role = match message.role {
            DirectRole::System => LlmRole::System,
            DirectRole::User => LlmRole::User,
            DirectRole::Assistant => LlmRole::Assistant,
        };
        let mut blocks: Vec<LlmContentBlock> = Vec::new();
        for part in message.parts {
            match part {
                DirectPart::Text(text) => {
                    if !text.is_empty() {
                        blocks.push(LlmContentBlock::Text {
                            text: text.into(),
                            response_meta: None,
                            cache_breakpoint: false,
                        });
                    }
                }
                DirectPart::Image(idx) => {
                    blocks.push(LlmContentBlock::Image {
                        attachment_idx: idx,
                    });
                }
            }
        }
        if !blocks.is_empty() {
            llm_messages.push(LlmMessage::new(role, blocks));
        }
    }

    LlmRequest {
        model,
        messages: llm_messages,
        attachments,
        tools: Vec::new().into(),
        tool_choice: LlmToolChoice::None,
        model_variant,
        generation,
        session_id,
        output_spec,
        stream_events,
        provider_trace: None,
    }
}

/// Chooses the stream-event sender to place on the transport request.
///
/// A sender supplied by the caller is always returned as-is. Otherwise, a
/// sender that ignores every event is created when the provider reports that
/// it requires streaming, and `None` is returned when it does not.
fn transport_stream_events_for_direct(
    provider: &ProviderHandle,
    requested: Option<LlmEventSender>,
) -> Option<LlmEventSender> {
    requested.or_else(|| {
        provider
            .requires_streaming()
            .then(|| LlmEventSender::new(|_event: LlmStreamEvent| {}))
    })
}