use std::rc::Rc;
use std::sync::Arc;
use crate::value::{VmChannelHandle, VmError, VmStream, VmStreamCancel, VmValue};
use super::api;
use super::call::build_llm_error_dict;
use super::helpers::extract_llm_options;
use super::stream::vm_stream_llm;
/// Starts an LLM call in a background local task and returns a `Channel`
/// value that yields text items as they are produced.
///
/// With `provider == "mock"`, the `content` of the last message is echoed back
/// one whitespace-separated word per channel item. Otherwise the work is
/// delegated to `vm_stream_llm`, which is handed the channel sender
/// (presumably to push text items itself — confirm in `super::stream`). If
/// that call fails, a single `"error: <message>"` string is sent as the last
/// item.
///
/// The handle's `closed` flag is raised only after the task has queued its
/// final item, so a reader that observes `closed` together with an empty
/// queue has seen everything.
///
/// # Errors
/// Returns whatever `extract_llm_options` returns for malformed arguments.
pub(super) async fn llm_stream_builtin(args: Vec<VmValue>) -> Result<VmValue, VmError> {
    let opts = extract_llm_options(&args)?;
    let provider = opts.provider.clone();
    // Only the mock provider reads this: it echoes the most recent message.
    let prompt_text = opts
        .messages
        .last()
        .and_then(|m| m["content"].as_str())
        .unwrap_or("")
        .to_string();
    let (tx, rx) = tokio::sync::mpsc::channel::<VmValue>(64);
    let closed = Arc::new(std::sync::atomic::AtomicBool::new(false));
    let closed_clone = closed.clone();
    // `VmValue` holds `Rc`s and is therefore not `Send`. The `Arc` presumably
    // exists only to match the `VmChannelHandle` field types; the task below
    // is `spawn_local`, so it stays on this thread.
    #[allow(clippy::arc_with_non_send_sync)]
    let tx_arc = Arc::new(tx);
    let tx_for_task = tx_arc.clone();
    tokio::task::spawn_local(async move {
        if provider == "mock" {
            for word in prompt_text.split_whitespace() {
                let _ = tx_for_task.send(VmValue::String(Rc::from(word))).await;
            }
        } else if let Err(e) = vm_stream_llm(&opts, &tx_for_task).await {
            let _ = tx_for_task
                .send(VmValue::String(Rc::from(format!("error: {e}"))))
                .await;
        }
        // Raise `closed` only after the final item (including any error
        // message) has been queued. Setting it before the error send opens
        // a window in which a reader sees `closed` with an empty queue and
        // stops, losing the error.
        closed_clone.store(true, std::sync::atomic::Ordering::Relaxed);
    });
    // Same reason as above: `rx` carries non-`Send` `VmValue`s.
    #[allow(clippy::arc_with_non_send_sync)]
    let handle = VmChannelHandle {
        name: Rc::from("llm_stream"),
        sender: tx_arc,
        receiver: Arc::new(tokio::sync::Mutex::new(rx)),
        closed,
    };
    Ok(VmValue::Channel(handle))
}
/// Builds one chunk dict for `llm_stream_call` streams.
///
/// Keys:
/// - `delta`, `visible_delta`, `partial`: the given strings.
/// - `role`: always `"assistant"`.
/// - `finish_reason`: the given reason, or `nil` when `finish_reason` is
///   `None` (intermediate chunks).
fn llm_stream_chunk(
    delta: &str,
    visible_delta: &str,
    partial: &str,
    finish_reason: Option<&str>,
) -> VmValue {
    // `Rc<str>: From<&str>` copies the text straight into the shared
    // allocation. Going through `to_string()` first would allocate twice.
    let text = |s: &str| VmValue::String(Rc::from(s));
    let dict = std::collections::BTreeMap::from([
        ("delta".to_string(), text(delta)),
        ("visible_delta".to_string(), text(visible_delta)),
        ("partial".to_string(), text(partial)),
        ("role".to_string(), text("assistant")),
        (
            "finish_reason".to_string(),
            finish_reason.map_or(VmValue::Nil, text),
        ),
    ]);
    VmValue::Dict(Rc::new(dict))
}
/// Runs one raw model delta through `visible` and forwards the resulting chunk
/// (with `finish_reason: nil`) to the stream consumer.
///
/// On success, returns the `partial` text reported by
/// `VisibleTextState::push`. Returns `Err(())` when the send fails, i.e. the
/// receiving half of `stream_tx` is closed.
async fn forward_llm_stream_delta(
    stream_tx: &tokio::sync::mpsc::Sender<Result<VmValue, VmError>>,
    visible: &mut crate::visible_text::VisibleTextState,
    delta: String,
) -> Result<String, ()> {
    let (current_partial, shown_delta) = visible.push(&delta, true);
    match stream_tx
        .send(Ok(llm_stream_chunk(&delta, &shown_delta, &current_partial, None)))
        .await
    {
        Ok(()) => Ok(current_partial),
        Err(_) => Err(()),
    }
}
/// Delivers `err` to the stream consumer as a thrown value built by
/// `build_llm_error_dict` for `provider` / `model`.
///
/// Delivery is best-effort: if the consumer has already closed the stream,
/// the error is silently discarded.
async fn send_llm_stream_error(
    stream_tx: &tokio::sync::mpsc::Sender<Result<VmValue, VmError>>,
    err: VmError,
    provider: &str,
    model: &str,
) {
    let error_payload = build_llm_error_dict(&err, provider, model);
    let thrown = VmError::Thrown(error_payload);
    // A failed send only means nobody is listening any more.
    stream_tx.send(Err(thrown)).await.ok();
}
/// Implements `llm_stream_call`: starts a streaming LLM call and returns a
/// `VmValue::Stream` whose items are chunk dicts built by `llm_stream_chunk`.
///
/// Each raw text delta produced by `api::vm_call_llm_full_streaming` is run
/// through a `VisibleTextState` and forwarded as a chunk with
/// `finish_reason: nil`. When the call returns `Ok`, one final chunk is sent
/// with:
/// - empty `delta` / `visible_delta`;
/// - the last `partial`;
/// - the result's `stop_reason`.
///
/// When the call returns `Err`, or its task fails for a reason other than
/// cancellation, an `Err(VmError::Thrown(..))` item built by
/// `build_llm_error_dict` is sent instead.
///
/// The background pump aborts the LLM task and stops when any of these
/// happens:
/// - `cancel_rx.changed()` resolves (presumably the stream was cancelled or
///   its `VmStreamCancel` dropped — confirm);
/// - the consumer closes the stream receiver;
/// - forwarding a delta fails.
///
/// # Errors
/// Returns whatever `extract_llm_options` returns for malformed arguments.
/// Later failures are reported through the stream instead.
pub(super) async fn llm_stream_call_impl(args: Vec<VmValue>) -> Result<VmValue, VmError> {
    let opts = extract_llm_options(&args)?;
    // Cloned up front: `opts` is moved into the LLM task, but these are still
    // needed to label errors afterwards.
    let provider = opts.provider.clone();
    let model = opts.model.clone();
    // Bounded channel to the VM consumer: when it is full, `send().await`
    // waits, so a slow reader applies backpressure to the loop below.
    let (stream_tx, stream_rx) = tokio::sync::mpsc::channel::<Result<VmValue, VmError>>(64);
    // Raw text deltas from the provider call. Unbounded, so sending a delta
    // never waits on the forwarding loop.
    let (delta_tx, mut delta_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
    let cancel = VmStreamCancel::new();
    let mut cancel_rx = cancel.subscribe();
    tokio::task::spawn_local(async move {
        let mut visible = crate::visible_text::VisibleTextState::default();
        // Most recent `partial` text; repeated in the final chunk.
        let mut partial = String::new();
        // Cleared once the delta channel yields `None`. This disables that
        // `select!` branch so it does not keep resolving immediately.
        let mut deltas_open = true;
        let mut llm_task = tokio::task::spawn_local(async move {
            api::vm_call_llm_full_streaming(&opts, delta_tx).await
        });
        loop {
            tokio::select! {
                // The result is ignored: any resolution of `changed()` is
                // treated as cancellation.
                _ = cancel_rx.changed() => {
                    llm_task.abort();
                    break;
                }
                // The stream receiver was closed; nobody is reading any more.
                _ = stream_tx.closed() => {
                    llm_task.abort();
                    break;
                }
                maybe_delta = delta_rx.recv(), if deltas_open => {
                    match maybe_delta {
                        Some(delta) => {
                            match forward_llm_stream_delta(&stream_tx, &mut visible, delta).await {
                                Ok(next_partial) => partial = next_partial,
                                // Send failed: the consumer is gone.
                                Err(()) => {
                                    llm_task.abort();
                                    break;
                                }
                            }
                        }
                        None => deltas_open = false,
                    }
                }
                joined = &mut llm_task => {
                    // `select!` may pick this branch while deltas are still
                    // queued. Flush them so the final chunk or error comes
                    // after every delta.
                    while let Ok(delta) = delta_rx.try_recv() {
                        match forward_llm_stream_delta(&stream_tx, &mut visible, delta).await {
                            Ok(next_partial) => partial = next_partial,
                            // Consumer gone; stop draining. The send below
                            // then fails and its result is ignored.
                            Err(()) => break,
                        }
                    }
                    match joined {
                        Ok(Ok(result)) => {
                            let final_chunk = llm_stream_chunk(
                                "",
                                "",
                                &partial,
                                result.stop_reason.as_deref(),
                            );
                            let _ = stream_tx.send(Ok(final_chunk)).await;
                        }
                        Ok(Err(err)) => {
                            send_llm_stream_error(&stream_tx, err, &provider, &model).await;
                        }
                        // A cancelled task has nothing to report. The aborts
                        // above break out of the loop before rejoining, so this
                        // presumably only covers external cancellation.
                        Err(join_err) if join_err.is_cancelled() => {}
                        // Any other join error (e.g. the LLM task panicked) is
                        // surfaced to the consumer as a thrown error.
                        Err(join_err) => {
                            let err = VmError::Thrown(VmValue::String(Rc::from(format!(
                                "llm_stream_call background task failed: {join_err}"
                            ))));
                            send_llm_stream_error(&stream_tx, err, &provider, &model).await;
                        }
                    }
                    break;
                }
            }
        }
    });
    Ok(VmValue::Stream(VmStream {
        done: Rc::new(std::cell::Cell::new(false)),
        receiver: Rc::new(tokio::sync::Mutex::new(stream_rx)),
        // Signalling this handle resolves `cancel_rx.changed()` in the pump above.
        cancel: Some(cancel),
    }))
}