// harn-vm 0.8.22
//
// Async bytecode virtual machine for the Harn programming language.
// See the crate documentation for details.
use std::rc::Rc;
use std::sync::Arc;

use crate::value::{VmChannelHandle, VmError, VmStream, VmStreamCancel, VmValue};

use super::api;
use super::call::build_llm_error_dict;
use super::helpers::extract_llm_options;
use super::stream::vm_stream_llm;

/// Implementation of the `llm_stream` builtin: starts an LLM call and returns
/// a `VmValue::Channel` that receives the response as it streams in.
///
/// The work runs on a `spawn_local` task feeding a bounded channel (64
/// slots). For the `"mock"` provider the task makes no provider call and
/// instead echoes the `content` of the last message, one whitespace-separated
/// word per value. Otherwise it delegates to `vm_stream_llm`; if that fails, a
/// final `"error: {e}"` string is sent on the channel.
///
/// The handle's `closed` flag is raised only after the task has sent its last
/// value (including any trailing error message), so a consumer that stops
/// reading once it observes the flag cannot miss that message.
///
/// # Errors
///
/// Returns the error reported by `extract_llm_options` when the arguments
/// cannot be parsed. Failures during streaming are reported in-band on the
/// channel rather than as an `Err`.
pub(super) async fn llm_stream_builtin(args: Vec<VmValue>) -> Result<VmValue, VmError> {
    let opts = extract_llm_options(&args)?;

    let (tx, rx) = tokio::sync::mpsc::channel::<VmValue>(64);
    let closed = Arc::new(std::sync::atomic::AtomicBool::new(false));
    let closed_clone = closed.clone();
    // `VmValue` holds `Rc`s, so this `Arc` is not `Send`/`Sync`; the channel
    // handle's fields are `Arc`-typed, and the producer task is `spawn_local`.
    #[allow(clippy::arc_with_non_send_sync)]
    let tx_arc = Arc::new(tx);
    let tx_for_task = tx_arc.clone();

    tokio::task::spawn_local(async move {
        if opts.provider == "mock" {
            // Echo the last message's text back word by word; no provider call.
            let prompt_text = opts
                .messages
                .last()
                .and_then(|m| m["content"].as_str())
                .unwrap_or("");
            for word in prompt_text.split_whitespace() {
                let _ = tx_for_task.send(VmValue::String(Rc::from(word))).await;
            }
        } else if let Err(e) = vm_stream_llm(&opts, &tx_for_task).await {
            let _ = tx_for_task
                .send(VmValue::String(Rc::from(format!("error: {e}"))))
                .await;
        }
        // Raise `closed` only after the final value has been sent; raising it
        // earlier lets a consumer stop before the error message arrives.
        closed_clone.store(true, std::sync::atomic::Ordering::Relaxed);
    });

    // Same reasoning as above: the receiver `Arc` wraps a non-`Send` payload.
    #[allow(clippy::arc_with_non_send_sync)]
    let handle = VmChannelHandle {
        name: Rc::from("llm_stream"),
        sender: tx_arc,
        receiver: Arc::new(tokio::sync::Mutex::new(rx)),
        closed,
    };
    Ok(VmValue::Channel(handle))
}

/// Builds one structured chunk dict for `llm_stream_call` streams.
///
/// The returned `VmValue::Dict` has these keys:
/// - `delta`: the raw text this chunk adds
/// - `visible_delta`: the caller-supplied visible portion of that text
/// - `partial`: the caller-supplied accumulated text so far
/// - `role`: always `"assistant"`
/// - `finish_reason`: the given reason, or `nil` when `None`
fn llm_stream_chunk(
    delta: &str,
    visible_delta: &str,
    partial: &str,
    finish_reason: Option<&str>,
) -> VmValue {
    // `Rc<str>` can be built straight from `&str`; going through `to_string()`
    // first would allocate a throwaway `String` and then copy it again.
    fn text(s: &str) -> VmValue {
        VmValue::String(Rc::from(s))
    }

    let mut dict = std::collections::BTreeMap::new();
    dict.insert("delta".to_string(), text(delta));
    dict.insert("visible_delta".to_string(), text(visible_delta));
    dict.insert("partial".to_string(), text(partial));
    dict.insert("role".to_string(), text("assistant"));
    dict.insert(
        "finish_reason".to_string(),
        finish_reason.map_or(VmValue::Nil, text),
    );
    VmValue::Dict(Rc::new(dict))
}

/// Runs one raw provider delta through the visible-text state and forwards
/// the resulting chunk to the stream consumer.
///
/// On success returns the updated partial text reported by `visible.push`.
/// Returns `Err(())` when the send fails, i.e. the consumer's receiver has
/// been dropped.
async fn forward_llm_stream_delta(
    stream_tx: &tokio::sync::mpsc::Sender<Result<VmValue, VmError>>,
    visible: &mut crate::visible_text::VisibleTextState,
    delta: String,
) -> Result<String, ()> {
    let (accumulated, shown) = visible.push(&delta, true);
    let outcome = stream_tx
        .send(Ok(llm_stream_chunk(&delta, &shown, &accumulated, None)))
        .await;
    match outcome {
        Ok(()) => Ok(accumulated),
        Err(_) => Err(()),
    }
}

/// Reports a streaming failure to the consumer as a thrown error value.
///
/// The thrown payload comes from `build_llm_error_dict`, using the given
/// `provider` and `model`. Delivery is best-effort: if the consumer's receiver
/// is already gone, the error is silently discarded.
async fn send_llm_stream_error(
    stream_tx: &tokio::sync::mpsc::Sender<Result<VmValue, VmError>>,
    err: VmError,
    provider: &str,
    model: &str,
) {
    let payload = build_llm_error_dict(&err, provider, model);
    // A failed send only means nobody is listening any more.
    let _ = stream_tx.send(Err(VmError::Thrown(payload))).await;
}

/// Shared implementation of `llm_stream_call`: a first-class `Stream`
/// of structured chunks using `llm_call`'s provider error taxonomy.
///
/// The provider call (`api::vm_call_llm_full_streaming`) runs on its own
/// `spawn_local` task and reports raw text deltas over an unbounded channel.
/// A second `spawn_local` driver task turns each delta into a chunk dict (via
/// `llm_stream_chunk`: `delta`, `visible_delta`, `partial`, `role`,
/// `finish_reason`). It forwards each chunk on a bounded (64-slot) channel
/// that backs the returned `VmValue::Stream`.
///
/// When the provider call finishes successfully, one final chunk is sent with
/// empty deltas, the accumulated `partial` text, and the result's
/// `stop_reason`. Provider errors, and provider-task failures other than
/// cancellation, are sent as `Err(VmError::Thrown(..))` items built by
/// `build_llm_error_dict` with this call's provider and model.
///
/// The provider task is aborted when the stream's cancel subscription fires,
/// when the consumer drops its receiver, or when forwarding a chunk fails.
///
/// # Errors
///
/// Returns the error from `extract_llm_options` if the arguments cannot be
/// parsed. All later failures are delivered in-band on the stream.
pub(super) async fn llm_stream_call_impl(args: Vec<VmValue>) -> Result<VmValue, VmError> {
    let opts = extract_llm_options(&args)?;
    // Cloned up front because `opts` is moved into the provider task below,
    // while the driver still needs these for error reporting.
    let provider = opts.provider.clone();
    let model = opts.model.clone();

    // Driver -> consumer: bounded, so a slow consumer applies backpressure.
    let (stream_tx, stream_rx) = tokio::sync::mpsc::channel::<Result<VmValue, VmError>>(64);
    // Provider -> driver: unbounded, so the provider task never waits on the driver.
    let (delta_tx, mut delta_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
    let cancel = VmStreamCancel::new();
    let mut cancel_rx = cancel.subscribe();

    tokio::task::spawn_local(async move {
        let mut visible = crate::visible_text::VisibleTextState::default();
        // Latest accumulated text, reused for the final chunk.
        let mut partial = String::new();
        // Cleared once the delta channel reports closed, so its branch below
        // stops being polled.
        let mut deltas_open = true;
        let mut llm_task = tokio::task::spawn_local(async move {
            api::vm_call_llm_full_streaming(&opts, delta_tx).await
        });

        // Every exit path that abandons the provider call aborts `llm_task`
        // and then immediately breaks out of this loop.
        loop {
            tokio::select! {
                // Any resolution of `changed()` aborts the call; the result is
                // ignored. NOTE(review): presumably this is a watch-style
                // receiver that also resolves (with an error) if the cancel
                // sender is dropped - confirm against `VmStreamCancel`.
                _ = cancel_rx.changed() => {
                    llm_task.abort();
                    break;
                }
                // The consumer dropped its receiver; nobody will read further chunks.
                _ = stream_tx.closed() => {
                    llm_task.abort();
                    break;
                }
                maybe_delta = delta_rx.recv(), if deltas_open => {
                    match maybe_delta {
                        Some(delta) => {
                            // NOTE(review): while this send is waiting for
                            // room in the bounded stream channel, the other
                            // branches (including cancellation) are not
                            // polled. A dropped receiver still unblocks it,
                            // because the send then fails.
                            match forward_llm_stream_delta(&stream_tx, &mut visible, delta).await {
                                Ok(next_partial) => partial = next_partial,
                                Err(()) => {
                                    llm_task.abort();
                                    break;
                                }
                            }
                        }
                        None => deltas_open = false,
                    }
                }
                joined = &mut llm_task => {
                    // `select!` does not prioritize branches, so deltas sent
                    // just before the task finished may still be queued;
                    // flush them before emitting the final result.
                    while let Ok(delta) = delta_rx.try_recv() {
                        match forward_llm_stream_delta(&stream_tx, &mut visible, delta).await {
                            Ok(next_partial) => partial = next_partial,
                            // Receiver gone: stop draining. The sends below
                            // then fail and are ignored.
                            Err(()) => break,
                        }
                    }
                    match joined {
                        Ok(Ok(result)) => {
                            let final_chunk = llm_stream_chunk(
                                "",
                                "",
                                &partial,
                                result.stop_reason.as_deref(),
                            );
                            let _ = stream_tx.send(Ok(final_chunk)).await;
                        }
                        Ok(Err(err)) => {
                            send_llm_stream_error(&stream_tx, err, &provider, &model).await;
                        }
                        // The loop always breaks right after `abort()`, so a
                        // cancelled join is not expected here; nothing to report.
                        Err(join_err) if join_err.is_cancelled() => {}
                        // Any other join failure (e.g. a panic in the provider
                        // task) is surfaced to the consumer as an LLM error.
                        Err(join_err) => {
                            let err = VmError::Thrown(VmValue::String(Rc::from(format!(
                                "llm_stream_call background task failed: {join_err}"
                            ))));
                            send_llm_stream_error(&stream_tx, err, &provider, &model).await;
                        }
                    }
                    break;
                }
            }
        }
    });

    Ok(VmValue::Stream(VmStream {
        done: Rc::new(std::cell::Cell::new(false)),
        receiver: Rc::new(tokio::sync::Mutex::new(stream_rx)),
        cancel: Some(cancel),
    }))
}