// rig_volcengine/streaming.rs

1use rig::completion::{CompletionError, CompletionRequest};
2use rig::providers::openai::send_compatible_streaming_request;
3use rig::streaming::StreamingCompletionResponse;
4use serde_json::json;
5use tracing::info_span;
6
7use super::completion::CompletionModel;
8
9/// Local deep-merge helper (same rule as in completion.rs)
10fn merge(left: serde_json::Value, right: serde_json::Value) -> serde_json::Value {
11    match (left, right) {
12        (serde_json::Value::Object(mut a), serde_json::Value::Object(b)) => {
13            for (k, v) in b {
14                let merged = match a.remove(&k) {
15                    Some(existing) => merge(existing, v),
16                    None => v,
17                };
18                a.insert(k, merged);
19            }
20            serde_json::Value::Object(a)
21        }
22        (_, r) => r,
23    }
24}
25
/// Streams a chat completion from the Volcengine Ark `/chat/completions`
/// endpoint.
///
/// Builds the provider request via `model.create_completion_request`, forces
/// the OpenAI-compatible streaming flags (`stream: true` and
/// `stream_options.include_usage` so the final SSE chunk reports token
/// usage), then delegates wire handling to rig's shared
/// `send_compatible_streaming_request`, instrumented with a `gen_ai.*`
/// tracing span.
///
/// # Errors
/// Returns `CompletionError` if the request cannot be constructed,
/// serialized, or the HTTP request cannot be built or sent.
pub(crate) async fn stream_completion<T>(
    model: &CompletionModel<T>,
    request: CompletionRequest,
) -> Result<
    StreamingCompletionResponse<
        <CompletionModel<T> as rig::completion::CompletionModel>::StreamingResponse,
    >,
    CompletionError,
>
where
    T: rig::http_client::HttpClientExt + Clone + Default + Send + 'static,
{
    // Capture the system prompt before `request` is consumed below; it is
    // only needed for the tracing span.
    let preamble = request.preamble.clone();
    let mut request = model.create_completion_request(request)?;

    // Ark chat streaming: OpenAI-compatible flags. Deep-merged so any
    // caller-supplied `stream_options` keys are preserved.
    request = merge(
        request,
        json!({"stream": true, "stream_options": {"include_usage": true}}),
    );

    let req = model
        .client
        .post("/chat/completions")?
        .header("Content-Type", "application/json")
        .body(serde_json::to_vec(&request)?)
        .map_err(|e| CompletionError::HttpError(e.into()))?;

    // Only open a fresh span when the caller hasn't already established one;
    // otherwise record onto the current span so nesting stays intact.
    let span = if tracing::Span::current().is_disabled() {
        info_span!(
            target: "rig::completions",
            "chat_streaming",
            gen_ai.operation.name = "chat_streaming",
            gen_ai.provider.name = "volcengine",
            gen_ai.request.model = model.model,
            gen_ai.system_instructions = preamble,
            // Response/usage fields are Empty here and filled in as the
            // stream progresses.
            gen_ai.response.id = tracing::field::Empty,
            gen_ai.response.model = tracing::field::Empty,
            gen_ai.usage.output_tokens = tracing::field::Empty,
            gen_ai.usage.input_tokens = tracing::field::Empty,
            // NOTE(review): serializing a serde_json::Value cannot fail for
            // string-keyed maps, so this unwrap is considered infallible.
            gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap_or(&json!([]))).unwrap(),
            gen_ai.output.messages = tracing::field::Empty,
        )
    } else {
        tracing::Span::current()
    };

    // Attach the span to the streaming future so every poll is recorded
    // within it.
    tracing::Instrument::instrument(
        send_compatible_streaming_request(model.client.http_client.clone(), req),
        span,
    )
    .await
}