1use std::{
2 fmt::{Debug, Display},
3 ops::{Deref, DerefMut},
4 path::{Path, PathBuf},
5 str::FromStr,
6 sync::{
7 Arc,
8 atomic::{AtomicU64, Ordering},
9 },
10 time::Duration,
11};
12
13use async_openai::{
14 Client,
15 config::{AzureConfig, OpenAIConfig},
16 error::OpenAIError,
17 types::chat::{
18 ChatChoice, ChatCompletionMessageToolCall, ChatCompletionMessageToolCalls,
19 ChatCompletionNamedToolChoiceCustom, ChatCompletionRequestAssistantMessageContent,
20 ChatCompletionRequestAssistantMessageContentPart,
21 ChatCompletionRequestDeveloperMessageContent,
22 ChatCompletionRequestDeveloperMessageContentPart, ChatCompletionRequestMessage,
23 ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestSystemMessageContent,
24 ChatCompletionRequestSystemMessageContentPart, ChatCompletionRequestToolMessageContent,
25 ChatCompletionRequestToolMessageContentPart, ChatCompletionRequestUserMessageArgs,
26 ChatCompletionRequestUserMessageContent, ChatCompletionRequestUserMessageContentPart,
27 ChatCompletionResponseMessage, ChatCompletionResponseStream, ChatCompletionStreamOptions,
28 ChatCompletionToolChoiceOption, ChatCompletionTools, CompletionUsage,
29 CreateChatCompletionRequest, CreateChatCompletionRequestArgs, CreateChatCompletionResponse,
30 CreateChatCompletionStreamResponse, CustomName, FinishReason, FunctionCall, Role,
31 ToolChoiceOptions,
32 },
33};
34use clap::Args;
35use color_eyre::{
36 Result,
37 eyre::{OptionExt, eyre},
38};
39use futures_util::StreamExt;
40use itertools::Itertools;
41use log::{debug, info, trace, warn};
42use serde::{Deserialize, Serialize};
43use tokio::{io::AsyncWriteExt, sync::RwLock};
44
45use crate::{OpenAIModel, error::PromptError};
46
/// Accumulator for a single tool call assembled from streaming deltas.
///
/// While streaming, the API delivers tool-call fragments (id, name and
/// argument chunks) spread over several chunks; `complete_streaming`
/// folds them into one of these before building the final response.
#[derive(Clone, Debug, Default)]
struct ToolCallAcc {
    // Tool-call id; may arrive in a later delta than the name/arguments.
    id: String,
    // Function name; overwritten whenever a delta carries it.
    name: String,
    // Argument JSON, concatenated from possibly many partial strings.
    arguments: String,
}
53
/// Newtype around [`ChatCompletionToolChoiceOption`] so a tool-choice can
/// be parsed from a CLI/env string (see the `FromStr` impl below).
#[derive(Debug, Clone)]
pub struct LLMToolChoice(pub ChatCompletionToolChoiceOption);
57
58impl FromStr for LLMToolChoice {
59 type Err = PromptError;
60 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
61 Ok(match s {
62 "auto" => Self(ChatCompletionToolChoiceOption::Mode(
63 ToolChoiceOptions::Auto,
64 )),
65 "required" => Self(ChatCompletionToolChoiceOption::Mode(
66 ToolChoiceOptions::Required,
67 )),
68 "none" => Self(ChatCompletionToolChoiceOption::Mode(
69 ToolChoiceOptions::None,
70 )),
71 _ => Self(ChatCompletionToolChoiceOption::Custom(
72 ChatCompletionNamedToolChoiceCustom {
73 custom: CustomName {
74 name: s.to_string(),
75 },
76 },
77 )),
78 })
79 }
80}
81
/// Lets an `LLMToolChoice` be used wherever a
/// `&ChatCompletionToolChoiceOption` is expected.
impl Deref for LLMToolChoice {
    type Target = ChatCompletionToolChoiceOption;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
88
/// Mutable counterpart of the `Deref` impl above.
impl DerefMut for LLMToolChoice {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}
94
/// Wraps a raw tool-choice option into the CLI-parseable newtype.
impl From<ChatCompletionToolChoiceOption> for LLMToolChoice {
    fn from(value: ChatCompletionToolChoiceOption) -> Self {
        Self(value)
    }
}
100
/// Unwraps the newtype back into the raw option (used when the value is
/// handed to the request builder's `tool_choice`).
impl From<LLMToolChoice> for ChatCompletionToolChoiceOption {
    fn from(value: LLMToolChoice) -> Self {
        value.0
    }
}
106
// Per-request generation settings, configurable via CLI flags or env
// vars (clap `Args`). Plain `//` comments are used on purpose so the
// generated --help text is unchanged.
#[derive(Args, Clone, Debug)]
pub struct LLMSettings {
    // Sampling temperature forwarded to the chat-completion request.
    #[arg(long, env = "LLM_TEMPERATURE", default_value_t = 0.8)]
    pub llm_temperature: f32,

    // Presence penalty forwarded to the chat-completion request.
    #[arg(long, env = "LLM_PRESENCE_PENALTY", default_value_t = 0.0)]
    pub llm_presence_penalty: f32,

    // Per-attempt timeout in seconds; 0 means "no timeout" (see
    // `prompt_once_with_retry`).
    #[arg(long, env = "LLM_PROMPT_TIMEOUT", default_value_t = 120)]
    pub llm_prompt_timeout: u64,

    // Maximum number of attempts in the retry loop.
    #[arg(long, env = "LLM_RETRY", default_value_t = 5)]
    pub llm_retry: u64,

    // Upper bound on completion tokens requested from the model.
    #[arg(long, env = "LLM_MAX_COMPLETION_TOKENS", default_value_t = 16384)]
    pub llm_max_completion_tokens: u32,

    // Optional tool-choice: "auto" / "required" / "none" or a custom
    // tool name (see `LLMToolChoice::from_str`).
    // NOTE(review): env var "LLM_TOOL_CHOINCE" looks like a typo of
    // "LLM_TOOL_CHOICE" — renaming it would break existing deployments,
    // so it is only flagged here; confirm before changing.
    #[arg(long, env = "LLM_TOOL_CHOINCE")]
    pub llm_tool_choice: Option<LLMToolChoice>,

    // When true, use the streaming API and reassemble the response
    // locally (see `LLMInner::complete_streaming`).
    #[arg(
        long,
        env = "LLM_STREAM",
        default_value_t = false,
        value_parser = clap::builder::BoolishValueParser::new()
    )]
    pub llm_stream: bool,
}
135
// Connection/account-level configuration for the OpenAI or Azure OpenAI
// backend (clap `Args`), plus the nested generation settings. Plain `//`
// comments are used so the generated --help text is unchanged.
#[derive(Args, Clone, Debug)]
pub struct OpenAISetup {
    // Base URL for the plain OpenAI-compatible API.
    #[arg(
        long,
        env = "OPENAI_API_URL",
        default_value = "https://api.openai.com/v1"
    )]
    pub openai_url: String,

    // When set, the Azure backend is used instead of `openai_url`
    // (see `OpenAISetup::to_config`).
    #[arg(long, env = "AZURE_OPENAI_ENDPOINT")]
    pub azure_openai_endpoint: Option<String>,

    // API key for either backend; an empty key is sent when absent.
    #[arg(long, env = "OPENAI_API_KEY")]
    pub openai_key: Option<String>,

    // Azure deployment id; falls back to the model name when absent.
    #[arg(long, env = "AZURE_API_DEPLOYMENT")]
    pub azure_deployment: Option<String>,

    #[arg(long, env = "AZURE_API_VERSION", default_value = "2025-01-01-preview")]
    pub azure_api_version: String,

    // Spending cap in dollars, enforced by `ModelBilling`.
    // NOTE(review): `biling_cap` is a typo of `billing_cap`; it is a
    // public field so renaming would break callers — flagged only.
    #[arg(long, default_value_t = 10.0, env = "OPENAI_BILLING_CAP")]
    pub biling_cap: f64,

    // Model to use (also doubles as the default Azure deployment id).
    #[arg(long, env = "OPENAI_API_MODEL", default_value = "o1")]
    pub model: OpenAIModel,

    // When set, every request/response pair is dumped under this folder.
    #[arg(long, env = "LLM_DEBUG")]
    pub llm_debug: Option<PathBuf>,

    #[clap(flatten)]
    pub llm_settings: LLMSettings,
}
169
170impl OpenAISetup {
171 pub fn to_config(&self) -> SupportedConfig {
172 if let Some(ep) = self.azure_openai_endpoint.as_ref() {
173 let cfg = AzureConfig::new()
174 .with_api_base(ep)
175 .with_api_key(self.openai_key.clone().unwrap_or_default())
176 .with_deployment_id(
177 self.azure_deployment
178 .as_ref()
179 .unwrap_or(&self.model.to_string()),
180 )
181 .with_api_version(&self.azure_api_version);
182 SupportedConfig::Azure(cfg)
183 } else {
184 let cfg = OpenAIConfig::new()
185 .with_api_base(&self.openai_url)
186 .with_api_key(self.openai_key.clone().unwrap_or_default());
187 SupportedConfig::OpenAI(cfg)
188 }
189 }
190
191 pub fn to_llm(&self) -> LLM {
192 let billing = RwLock::new(ModelBilling::new(self.biling_cap));
193
194 let debug_path = if let Some(dbg) = self.llm_debug.as_ref() {
195 let pid = std::process::id();
196
197 let mut cnt = 0u64;
198 let debug_path;
199 loop {
200 let test_path = dbg.join(format!("{}-{}", pid, cnt));
201 if !test_path.exists() {
202 std::fs::create_dir_all(&test_path).expect("Fail to create llm debug path?");
203 debug_path = Some(test_path);
204 debug!("The path to save LLM interactions is {:?}", &debug_path);
205 break;
206 } else {
207 cnt += 1;
208 }
209 }
210 debug_path
211 } else {
212 None
213 };
214
215 LLM {
216 llm: Arc::new(LLMInner {
217 client: LLMClient::new(self.to_config()),
218 model: self.model.clone(),
219 billing,
220 llm_debug: debug_path,
221 llm_debug_index: AtomicU64::new(0),
222 default_settings: self.llm_settings.clone(),
223 }),
224 }
225 }
226}
227
/// Backend configuration variants produced by [`OpenAISetup::to_config`].
#[derive(Debug, Clone)]
pub enum SupportedConfig {
    Azure(AzureConfig),
    OpenAI(OpenAIConfig),
}
233
/// Concrete `async_openai` client for whichever backend was configured;
/// mirrors the variants of [`SupportedConfig`].
#[derive(Debug, Clone)]
pub enum LLMClient {
    Azure(Client<AzureConfig>),
    OpenAI(Client<OpenAIConfig>),
}
239
240impl LLMClient {
241 pub fn new(config: SupportedConfig) -> Self {
242 match config {
243 SupportedConfig::Azure(cfg) => Self::Azure(Client::with_config(cfg)),
244 SupportedConfig::OpenAI(cfg) => Self::OpenAI(Client::with_config(cfg)),
245 }
246 }
247
248 pub async fn create_chat(
249 &self,
250 req: CreateChatCompletionRequest,
251 ) -> Result<CreateChatCompletionResponse, OpenAIError> {
252 match self {
253 Self::Azure(cl) => cl.chat().create(req).await,
254 Self::OpenAI(cl) => cl.chat().create(req).await,
255 }
256 }
257
258 pub async fn create_chat_stream(
259 &self,
260 req: CreateChatCompletionRequest,
261 ) -> Result<ChatCompletionResponseStream, OpenAIError> {
262 match self {
263 Self::Azure(cl) => cl.chat().create_stream(req).await,
264 Self::OpenAI(cl) => cl.chat().create_stream(req).await,
265 }
266 }
267}
268
/// Running cost tracker: accumulated spend (`current`) checked against a
/// hard cap, both in dollars.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelBilling {
    pub current: f64,
    pub cap: f64,
}
274
275impl Display for ModelBilling {
276 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
277 f.write_fmt(format_args!("Billing({}/{})", self.current, self.cap))
278 }
279}
280
281impl ModelBilling {
282 pub fn new(cap: f64) -> Self {
283 Self { current: 0.0, cap }
284 }
285
286 pub fn in_cap(&self) -> bool {
287 self.current <= self.cap
288 }
289
290 pub fn input_tokens(&mut self, model: &OpenAIModel, count: u64) -> Result<()> {
291 let pricing = model.pricing();
292
293 self.current += (pricing.input_tokens * (count as f64)) / 1e6;
294
295 if self.in_cap() {
296 Ok(())
297 } else {
298 Err(eyre!("cap {} reached, current {}", self.cap, self.current))
299 }
300 }
301
302 pub fn output_tokens(&mut self, model: &OpenAIModel, count: u64) -> Result<()> {
303 let pricing = model.pricing();
304
305 self.current += pricing.output_tokens * (count as f64) / 1e6;
306
307 if self.in_cap() {
308 Ok(())
309 } else {
310 Err(eyre!("cap {} reached, current {}", self.cap, self.current))
311 }
312 }
313}
314
/// Cheaply clonable handle over the shared LLM state ([`LLMInner`]).
#[derive(Debug, Clone)]
pub struct LLM {
    // Shared so billing and the debug-file index are global across clones.
    pub llm: Arc<LLMInner>,
}
319
/// Forwards method calls on [`LLM`] to the shared [`LLMInner`].
impl Deref for LLM {
    type Target = LLMInner;

    fn deref(&self) -> &Self::Target {
        &self.llm
    }
}
327
328#[derive(Debug)]
329pub struct LLMInner {
330 pub client: LLMClient,
331 pub model: OpenAIModel,
332 pub billing: RwLock<ModelBilling>,
333 pub llm_debug: Option<PathBuf>,
334 pub llm_debug_index: AtomicU64,
335 pub default_settings: LLMSettings,
336}
337
338pub fn completion_to_role(msg: &ChatCompletionRequestMessage) -> &'static str {
339 match msg {
340 ChatCompletionRequestMessage::Assistant(_) => "ASSISTANT",
341 ChatCompletionRequestMessage::Developer(_) => "DEVELOPER",
342 ChatCompletionRequestMessage::Function(_) => "FUNCTION",
343 ChatCompletionRequestMessage::System(_) => "SYSTEM",
344 ChatCompletionRequestMessage::Tool(_) => "TOOL",
345 ChatCompletionRequestMessage::User(_) => "USER",
346 }
347}
348
349pub fn toolcall_to_string(t: &ChatCompletionMessageToolCalls) -> String {
350 match t {
351 ChatCompletionMessageToolCalls::Function(t) => {
352 format!(
353 "<toolcall name=\"{}\">\n{}\n</toolcall>",
354 &t.function.name, &t.function.arguments
355 )
356 }
357 ChatCompletionMessageToolCalls::Custom(t) => {
358 format!(
359 "<customtoolcall name=\"{}\">\n{}\n</customtoolcall>",
360 &t.custom_tool.name, &t.custom_tool.input
361 )
362 }
363 }
364}
365
366pub fn response_to_string(resp: &ChatCompletionResponseMessage) -> String {
367 let mut s = String::new();
368 if let Some(content) = resp.content.as_ref() {
369 s += content;
370 s += "\n";
371 }
372
373 if let Some(tools) = resp.tool_calls.as_ref() {
374 s += &tools.iter().map(|t| toolcall_to_string(t)).join("\n");
375 }
376
377 if let Some(refusal) = &resp.refusal {
378 s += refusal;
379 s += "\n";
380 }
381
382 let role = resp.role.to_string().to_uppercase();
383
384 format!("<{}>\n{}\n</{}>\n", &role, s, &role)
385}
386
/// Renders any request message as an XML-ish `<ROLE>...</ROLE>` snippet
/// for the debug dumps, flattening multi-part contents with `<cont/>`
/// separators and substituting `<none/>` for absent content.
pub fn completion_to_string(msg: &ChatCompletionRequestMessage) -> String {
    // Separator between the parts of an array-valued content.
    const CONT: &str = "<cont/>\n";
    // Placeholder when a message carries no content at all.
    const NONE: &str = "<none/>\n";
    let role = completion_to_role(msg);
    let content = match msg {
        ChatCompletionRequestMessage::Assistant(ass) => {
            // Assistant content is optional; tool calls are appended
            // after it on separate lines.
            let msg = ass
                .content
                .as_ref()
                .map(|ass| match ass {
                    ChatCompletionRequestAssistantMessageContent::Text(s) => s.clone(),
                    ChatCompletionRequestAssistantMessageContent::Array(arr) => arr
                        .iter()
                        .map(|v| match v {
                            ChatCompletionRequestAssistantMessageContentPart::Text(s) => {
                                s.text.clone()
                            }
                            ChatCompletionRequestAssistantMessageContentPart::Refusal(rf) => {
                                rf.refusal.clone()
                            }
                        })
                        .join(CONT),
                })
                .unwrap_or(NONE.to_string());
            let tool_calls = ass
                .tool_calls
                .iter()
                .flatten()
                .map(|t| toolcall_to_string(t))
                .join("\n");
            format!("{}\n{}", msg, tool_calls)
        }
        ChatCompletionRequestMessage::Developer(dev) => match &dev.content {
            ChatCompletionRequestDeveloperMessageContent::Text(t) => t.clone(),
            ChatCompletionRequestDeveloperMessageContent::Array(arr) => arr
                .iter()
                .map(|v| match v {
                    ChatCompletionRequestDeveloperMessageContentPart::Text(v) => v.text.clone(),
                })
                .join(CONT),
        },
        ChatCompletionRequestMessage::Function(f) => f.content.clone().unwrap_or(NONE.to_string()),
        ChatCompletionRequestMessage::System(sys) => match &sys.content {
            ChatCompletionRequestSystemMessageContent::Text(t) => t.clone(),
            ChatCompletionRequestSystemMessageContent::Array(arr) => arr
                .iter()
                .map(|v| match v {
                    ChatCompletionRequestSystemMessageContentPart::Text(t) => t.text.clone(),
                })
                .join(CONT),
        },
        ChatCompletionRequestMessage::Tool(tool) => match &tool.content {
            ChatCompletionRequestToolMessageContent::Text(t) => t.clone(),
            ChatCompletionRequestToolMessageContent::Array(arr) => arr
                .iter()
                .map(|v| match v {
                    ChatCompletionRequestToolMessageContentPart::Text(t) => t.text.clone(),
                })
                .join(CONT),
        },
        // User content may mix text with media parts; media is rendered
        // as a tag rather than raw bytes.
        ChatCompletionRequestMessage::User(usr) => match &usr.content {
            ChatCompletionRequestUserMessageContent::Text(t) => t.clone(),
            ChatCompletionRequestUserMessageContent::Array(arr) => arr
                .iter()
                .map(|v| match v {
                    ChatCompletionRequestUserMessageContentPart::Text(t) => t.text.clone(),
                    ChatCompletionRequestUserMessageContentPart::ImageUrl(img) => {
                        format!("<img url=\"{}\"/>", &img.image_url.url)
                    }
                    ChatCompletionRequestUserMessageContentPart::InputAudio(audio) => {
                        format!("<audio>{}</audio>", audio.input_audio.data)
                    }
                    ChatCompletionRequestUserMessageContentPart::File(f) => {
                        format!("<file>{:?}</file>", f)
                    }
                })
                .join(CONT),
        },
    };

    format!("<{}>\n{}\n</{}>\n", role, content, role)
}
469
470impl LLMInner {
    /// Appends `t` as one JSON line to the sibling `<stem>.json` of
    /// `fpath`, falling back to the `Debug` rendering if serialization
    /// fails.
    ///
    /// The file is opened in append mode on purpose: the request and the
    /// response of one interaction share a single `.json` file (see
    /// `save_llm_user` / `save_llm_resp`), so despite the name this
    /// accumulates rather than rewrites.
    async fn rewrite_json<T: Serialize + Debug>(fpath: &Path, t: &T) -> Result<(), PromptError> {
        // Derive the `.json` path from the `.xml` debug file name.
        let mut json_fp = fpath.to_path_buf();
        json_fp.set_file_name(format!(
            "{}.json",
            json_fp
                .file_stem()
                .ok_or_eyre(eyre!("no filename"))?
                .to_str()
                .ok_or_eyre(eyre!("non-utf fname"))?
        ));

        let mut fp = tokio::fs::OpenOptions::new()
            .create(true)
            .append(true)
            .write(true)
            .open(&json_fp)
            .await?;
        // Best effort: a value that fails to serialize is logged via Debug.
        let s = match serde_json::to_string(&t) {
            Ok(s) => s,
            Err(_) => format!("{:?}", &t),
        };
        fp.write_all(s.as_bytes()).await?;
        fp.write_all(b"\n").await?;
        fp.flush().await?;

        Ok(())
    }
498
    /// Writes the full request (all messages plus tool definitions) to
    /// the debug file at `fpath`, truncating any previous content, then
    /// appends the JSON form via [`Self::rewrite_json`].
    async fn save_llm_user(
        fpath: &PathBuf,
        user_msg: &CreateChatCompletionRequest,
    ) -> Result<(), PromptError> {
        let mut fp = tokio::fs::OpenOptions::new()
            .create(true)
            .truncate(true)
            .write(true)
            .open(&fpath)
            .await?;
        fp.write_all(b"=====================\n<Request>\n").await?;
        for it in user_msg.messages.iter() {
            let msg = completion_to_string(it);
            fp.write_all(msg.as_bytes()).await?;
        }

        // Render tool definitions, if any; the Option<Vec<_>> flattens
        // into a possibly-empty iterator.
        let mut tools = vec![];
        for tool in user_msg
            .tools
            .as_ref()
            .map(|t| t.iter())
            .into_iter()
            .flatten()
        {
            let s = match tool {
                ChatCompletionTools::Function(tool) => {
                    format!(
                        "<tool name=\"{}\", description=\"{}\", strict={}>\n{}\n</tool>",
                        &tool.function.name,
                        &tool.function.description.clone().unwrap_or_default(),
                        tool.function.strict.unwrap_or_default(),
                        tool.function
                            .parameters
                            .as_ref()
                            .map(serde_json::to_string_pretty)
                            .transpose()?
                            .unwrap_or_default()
                    )
                }
                ChatCompletionTools::Custom(tool) => {
                    format!(
                        "<customtool name=\"{}\", description=\"{:?}\"></customtool>",
                        tool.custom.name, tool.custom.description
                    )
                }
            };
            tools.push(s);
        }
        fp.write_all(tools.join("\n").as_bytes()).await?;
        fp.write_all(b"\n</Request>\n=====================\n")
            .await?;
        fp.flush().await?;

        Self::rewrite_json(fpath, user_msg).await?;

        Ok(())
    }
556
    /// Appends the response choices to an existing debug file — the file
    /// must already have been created by [`Self::save_llm_user`], hence
    /// `create(false)` — then appends the JSON form.
    async fn save_llm_resp(fpath: &PathBuf, resp: &CreateChatCompletionResponse) -> Result<()> {
        let mut fp = tokio::fs::OpenOptions::new()
            .create(false)
            .append(true)
            .write(true)
            .open(&fpath)
            .await?;
        fp.write_all(b"=====================\n<Response>\n").await?;
        for it in &resp.choices {
            let msg = response_to_string(&it.message);
            fp.write_all(msg.as_bytes()).await?;
        }
        fp.write_all(b"\n</Response>\n=====================\n")
            .await?;
        fp.flush().await?;

        Self::rewrite_json(fpath, resp).await?;

        Ok(())
    }
577
578 fn on_llm_debug(&self, prefix: &str) -> Option<PathBuf> {
579 if let Some(output_folder) = self.llm_debug.as_ref() {
580 let idx = self.llm_debug_index.fetch_add(1, Ordering::SeqCst);
581 let fpath = output_folder.join(format!("{}-{:0>12}.xml", prefix, idx));
582 Some(fpath)
583 } else {
584 None
585 }
586 }
587
588 pub async fn prompt_once_with_retry(
590 &self,
591 sys_msg: &str,
592 user_msg: &str,
593 prefix: Option<&str>,
594 settings: Option<LLMSettings>,
595 ) -> Result<CreateChatCompletionResponse, PromptError> {
596 let settings = settings.unwrap_or_else(|| self.default_settings.clone());
597 let sys = ChatCompletionRequestSystemMessageArgs::default()
598 .content(sys_msg)
599 .build()?;
600
601 let user = ChatCompletionRequestUserMessageArgs::default()
602 .content(user_msg)
603 .build()?;
604 let mut req = CreateChatCompletionRequestArgs::default();
605 req.messages(vec![sys.into(), user.into()])
606 .model(self.model.to_string())
607 .temperature(settings.llm_temperature)
608 .presence_penalty(settings.llm_presence_penalty)
609 .max_completion_tokens(settings.llm_max_completion_tokens);
610
611 if let Some(tc) = settings.llm_tool_choice {
612 req.tool_choice(tc);
613 }
614 let req = req.build()?;
615
616 let timeout = if settings.llm_prompt_timeout == 0 {
617 Duration::MAX
618 } else {
619 Duration::from_secs(settings.llm_prompt_timeout)
620 };
621
622 self.complete_once_with_retry(&req, prefix, Some(timeout), Some(settings.llm_retry))
623 .await
624 }
625
    /// Runs [`Self::complete`] up to `retry` times, applying `timeout`
    /// to each attempt. `None` means "no timeout" / "retry practically
    /// forever" (both fall back to the corresponding `MAX`).
    pub async fn complete_once_with_retry(
        &self,
        req: &CreateChatCompletionRequest,
        prefix: Option<&str>,
        timeout: Option<Duration>,
        retry: Option<u64>,
    ) -> Result<CreateChatCompletionResponse, PromptError> {
        let timeout = if let Some(timeout) = timeout {
            timeout
        } else {
            Duration::MAX
        };

        let retry = if let Some(retry) = retry {
            retry
        } else {
            u64::MAX
        };

        // `last` holds the most recent attempt outcome. On a timeout we
        // `continue` without touching it, so after the loop it may still
        // carry an earlier error, or `None` if every attempt timed out.
        let mut last = None;
        for idx in 0..retry {
            match tokio::time::timeout(timeout, self.complete(req.clone(), prefix)).await {
                Ok(r) => {
                    last = Some(r);
                }
                Err(_) => {
                    warn!("Timeout with {} retry, timeout = {:?}", idx, timeout);
                    continue;
                }
            };

            match last {
                // First success wins.
                Some(Ok(r)) => return Ok(r),
                Some(Err(ref e)) => {
                    warn!(
                        "Having an error {} during {} retry (timeout is {:?})",
                        e, idx, timeout
                    );
                }
                _ => {}
            }
        }

        // All attempts exhausted: surface the last error. `None` here
        // means retry was zero or every attempt timed out; the `?`
        // unwraps the outer "did any attempt finish" layer and the final
        // expression returns the inner attempt result.
        last.ok_or_eyre(eyre!("retry is zero?!"))
            .map_err(PromptError::Other)?
    }
672
673 pub async fn complete(
674 &self,
675 req: CreateChatCompletionRequest,
676 prefix: Option<&str>,
677 ) -> Result<CreateChatCompletionResponse, PromptError> {
678 let use_stream = self.default_settings.llm_stream;
679 let prefix = if let Some(prefix) = prefix {
680 prefix.to_string()
681 } else {
682 "llm".to_string()
683 };
684 let debug_fp = self.on_llm_debug(&prefix);
685
686 if let Some(debug_fp) = debug_fp.as_ref() {
687 if let Err(e) = Self::save_llm_user(debug_fp, &req).await {
688 warn!("Fail to save user due to {}", e);
689 }
690 }
691
692 trace!(
693 "Sending completion request: {:?}",
694 &serde_json::to_string(&req)
695 );
696 let resp = if use_stream {
697 self.complete_streaming(req).await?
698 } else {
699 self.client.create_chat(req).await?
700 };
701
702 if let Some(debug_fp) = debug_fp.as_ref() {
703 if let Err(e) = Self::save_llm_resp(debug_fp, &resp).await {
704 warn!("Fail to save resp due to {}", e);
705 }
706 }
707
708 if let Some(usage) = &resp.usage {
709 self.billing
710 .write()
711 .await
712 .input_tokens(&self.model, usage.prompt_tokens as u64)
713 .map_err(PromptError::Other)?;
714 self.billing
715 .write()
716 .await
717 .output_tokens(&self.model, usage.completion_tokens as u64)
718 .map_err(PromptError::Other)?;
719 } else {
720 warn!("No usage?!")
721 }
722
723 info!("Model Billing: {}", &self.billing.read().await);
724 Ok(resp)
725 }
726
    /// Drives a streaming chat completion to the end and reassembles the
    /// chunks into a single [`CreateChatCompletionResponse`], so callers
    /// can treat streaming and non-streaming requests uniformly.
    async fn complete_streaming(
        &self,
        mut req: CreateChatCompletionRequest,
    ) -> Result<CreateChatCompletionResponse, PromptError> {
        // Ask the server to report token usage in the final chunk unless
        // the caller already configured stream options.
        if req.stream_options.is_none() {
            req.stream_options = Some(ChatCompletionStreamOptions {
                include_usage: Some(true),
                include_obfuscation: None,
            });
        }

        let mut stream = self.client.create_chat_stream(req).await?;

        // Response metadata: id is taken from the first chunk, the rest
        // from the most recent chunk that carried it.
        let mut id: Option<String> = None;
        let mut created: Option<u32> = None;
        let mut model: Option<String> = None;
        let mut service_tier = None;
        let mut system_fingerprint = None;
        let mut usage: Option<CompletionUsage> = None;

        // Per-choice accumulators, indexed by the choice index.
        let mut contents: Vec<String> = Vec::new();
        let mut finish_reasons: Vec<Option<FinishReason>> = Vec::new();
        let mut tool_calls: Vec<Vec<ToolCallAcc>> = Vec::new();

        while let Some(item) = stream.next().await {
            let chunk: CreateChatCompletionStreamResponse = item?;
            if id.is_none() {
                id = Some(chunk.id.clone());
            }
            created = Some(chunk.created);
            model = Some(chunk.model.clone());
            service_tier = chunk.service_tier.clone();
            system_fingerprint = chunk.system_fingerprint.clone();
            if let Some(u) = chunk.usage.clone() {
                usage = Some(u);
            }

            for ch in chunk.choices.into_iter() {
                let idx = ch.index as usize;
                // Grow the accumulators on demand — choice indices may
                // appear sparsely or out of order.
                if contents.len() <= idx {
                    contents.resize_with(idx + 1, String::new);
                    finish_reasons.resize_with(idx + 1, || None);
                    tool_calls.resize_with(idx + 1, Vec::new);
                }
                if let Some(delta) = ch.delta.content {
                    contents[idx].push_str(&delta);
                }
                if let Some(tcs) = ch.delta.tool_calls {
                    for tc in tcs.into_iter() {
                        let tc_idx = tc.index as usize;
                        if tool_calls[idx].len() <= tc_idx {
                            tool_calls[idx].resize_with(tc_idx + 1, ToolCallAcc::default);
                        }
                        // Merge the delta: id and name overwrite when
                        // present, argument fragments are concatenated.
                        let acc = &mut tool_calls[idx][tc_idx];
                        if let Some(id) = tc.id {
                            acc.id = id;
                        }
                        if let Some(func) = tc.function {
                            if let Some(name) = func.name {
                                acc.name = name;
                            }
                            if let Some(args) = func.arguments {
                                acc.arguments.push_str(&args);
                            }
                        }
                    }
                }
                if ch.finish_reason.is_some() {
                    finish_reasons[idx] = ch.finish_reason;
                }
            }
        }

        // Convert the accumulators into regular response choices.
        let mut choices = Vec::new();
        for (idx, content) in contents.into_iter().enumerate() {
            let finish_reason = finish_reasons.get(idx).cloned().unwrap_or(None);
            let built_tool_calls = tool_calls
                .get(idx)
                .cloned()
                .unwrap_or_default()
                .into_iter()
                // Drop fully-empty accumulators (a delta may create the
                // slot without ever filling it).
                .filter(|t| !t.name.trim().is_empty() || !t.arguments.trim().is_empty())
                .map(|t| {
                    ChatCompletionMessageToolCalls::Function(ChatCompletionMessageToolCall {
                        // Synthesize an id if the stream never sent one.
                        id: if t.id.trim().is_empty() {
                            format!("toolcall-{}", idx)
                        } else {
                            t.id
                        },
                        function: FunctionCall {
                            name: t.name,
                            arguments: t.arguments,
                        },
                    })
                })
                .collect::<Vec<_>>();
            let tool_calls_opt = if built_tool_calls.is_empty() {
                None
            } else {
                Some(built_tool_calls)
            };
            choices.push(ChatChoice {
                index: idx as u32,
                message: ChatCompletionResponseMessage {
                    content: if content.is_empty() {
                        None
                    } else {
                        Some(content)
                    },
                    refusal: None,
                    tool_calls: tool_calls_opt,
                    annotations: None,
                    role: Role::Assistant,
                    function_call: None,
                    audio: None,
                },
                finish_reason,
                logprobs: None,
            });
        }
        // Guarantee at least one (empty) choice so downstream code can
        // index `choices[0]` safely.
        if choices.is_empty() {
            choices.push(ChatChoice {
                index: 0,
                message: ChatCompletionResponseMessage {
                    content: Some(String::new()),
                    refusal: None,
                    tool_calls: None,
                    annotations: None,
                    role: Role::Assistant,
                    function_call: None,
                    audio: None,
                },
                finish_reason: None,
                logprobs: None,
            });
        }

        Ok(CreateChatCompletionResponse {
            id: id.unwrap_or_else(|| "stream".to_string()),
            choices,
            created: created.unwrap_or(0),
            model: model.unwrap_or_else(|| self.model.to_string()),
            service_tier,
            system_fingerprint,
            object: "chat.completion".to_string(),
            usage,
        })
    }
875
876 pub async fn prompt_once(
877 &self,
878 sys_msg: &str,
879 user_msg: &str,
880 prefix: Option<&str>,
881 settings: Option<LLMSettings>,
882 ) -> Result<CreateChatCompletionResponse, PromptError> {
883 let settings = settings.unwrap_or_else(|| self.default_settings.clone());
884 let sys = ChatCompletionRequestSystemMessageArgs::default()
885 .content(sys_msg)
886 .build()?;
887
888 let user = ChatCompletionRequestUserMessageArgs::default()
889 .content(user_msg)
890 .build()?;
891 let req = CreateChatCompletionRequestArgs::default()
892 .messages(vec![sys.into(), user.into()])
893 .model(self.model.to_string())
894 .temperature(settings.llm_temperature)
895 .presence_penalty(settings.llm_presence_penalty)
896 .max_completion_tokens(settings.llm_max_completion_tokens)
897 .build()?;
898 self.complete(req, prefix).await
899 }
900}