rig_volcengine/
streaming.rs1use rig::completion::{CompletionError, CompletionRequest};
2use rig::providers::openai::send_compatible_streaming_request;
3use rig::streaming::StreamingCompletionResponse;
4use serde_json::json;
5use tracing::info_span;
6
7use super::completion::CompletionModel;
8
9fn merge(left: serde_json::Value, right: serde_json::Value) -> serde_json::Value {
11 match (left, right) {
12 (serde_json::Value::Object(mut a), serde_json::Value::Object(b)) => {
13 for (k, v) in b {
14 let merged = match a.remove(&k) {
15 Some(existing) => merge(existing, v),
16 None => v,
17 };
18 a.insert(k, merged);
19 }
20 serde_json::Value::Object(a)
21 }
22 (_, r) => r,
23 }
24}
25
26pub(crate) async fn stream_completion(
27 model: &CompletionModel<reqwest::Client>,
28 request: CompletionRequest,
29) -> Result<
30 StreamingCompletionResponse<
31 <CompletionModel<reqwest::Client> as rig::completion::CompletionModel>::StreamingResponse,
32 >,
33 CompletionError,
34> {
35 let preamble = request.preamble.clone();
36 let mut request = model.create_completion_request(request)?;
37
38 request = merge(
40 request,
41 json!({"stream": true, "stream_options": {"include_usage": true}}),
42 );
43
44 let req = model
45 .client
46 .post("/chat/completions")?
47 .body(serde_json::to_vec(&request)?)
48 .map_err(|e| CompletionError::HttpError(e.into()))?;
49
50 let span = if tracing::Span::current().is_disabled() {
51 info_span!(
52 target: "rig::completions",
53 "chat_streaming",
54 gen_ai.operation.name = "chat_streaming",
55 gen_ai.provider.name = "volcengine",
56 gen_ai.request.model = model.model,
57 gen_ai.system_instructions = preamble,
58 gen_ai.response.id = tracing::field::Empty,
59 gen_ai.response.model = tracing::field::Empty,
60 gen_ai.usage.output_tokens = tracing::field::Empty,
61 gen_ai.usage.input_tokens = tracing::field::Empty,
62 gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap_or(&json!([]))).unwrap(),
63 gen_ai.output.messages = tracing::field::Empty,
64 )
65 } else {
66 tracing::Span::current()
67 };
68
69 tracing::Instrument::instrument(
70 send_compatible_streaming_request(model.client.http_client.clone(), req),
71 span,
72 )
73 .await
74}