1use std::{
2 fmt::{Debug, Display},
3 ops::{Deref, DerefMut},
4 path::{Path, PathBuf},
5 str::FromStr,
6 sync::{
7 Arc,
8 atomic::{AtomicU64, Ordering},
9 },
10 time::Duration,
11};
12
13use async_openai::{
14 Client,
15 config::{AzureConfig, OpenAIConfig},
16 error::OpenAIError,
17 types::chat::{
18 ChatChoice, ChatCompletionMessageToolCall, ChatCompletionMessageToolCalls,
19 ChatCompletionNamedToolChoiceCustom, ChatCompletionRequestAssistantMessageContent,
20 ChatCompletionRequestAssistantMessageContentPart,
21 ChatCompletionRequestDeveloperMessageContent,
22 ChatCompletionRequestDeveloperMessageContentPart, ChatCompletionRequestMessage,
23 ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestSystemMessageContent,
24 ChatCompletionRequestSystemMessageContentPart, ChatCompletionRequestToolMessageContent,
25 ChatCompletionRequestToolMessageContentPart, ChatCompletionRequestUserMessageArgs,
26 ChatCompletionRequestUserMessageContent, ChatCompletionRequestUserMessageContentPart,
27 ChatCompletionResponseMessage, ChatCompletionResponseStream, ChatCompletionStreamOptions,
28 ChatCompletionToolChoiceOption, ChatCompletionTools, CompletionUsage,
29 CreateChatCompletionRequest, CreateChatCompletionRequestArgs, CreateChatCompletionResponse,
30 CreateChatCompletionStreamResponse, CustomName, FinishReason, FunctionCall, Role,
31 ToolChoiceOptions,
32 },
33};
34use clap::Args;
35use color_eyre::{
36 Result,
37 eyre::{OptionExt, eyre},
38};
39use futures_util::StreamExt;
40use itertools::Itertools;
41use log::{debug, info, trace, warn};
42use serde::{Deserialize, Serialize};
43use tokio::{io::AsyncWriteExt, sync::RwLock};
44
45use crate::{OpenAIModel, error::PromptError};
46
/// Accumulator for one tool call assembled from streaming deltas: `id` and
/// `name` are overwritten by the latest chunk that carries them, while
/// `arguments` grows by appending each streamed JSON fragment
/// (see `complete_streaming`).
#[derive(Clone, Debug, Default)]
struct ToolCallAcc {
    id: String,
    name: String,
    arguments: String,
}
53
/// Newtype around [`ChatCompletionToolChoiceOption`] so a tool choice can be
/// parsed from a CLI flag or environment variable via `FromStr`.
#[derive(Debug, Clone)]
pub struct LLMToolChoice(pub ChatCompletionToolChoiceOption);
57
58impl FromStr for LLMToolChoice {
59 type Err = PromptError;
60 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
61 Ok(match s {
62 "auto" => Self(ChatCompletionToolChoiceOption::Mode(
63 ToolChoiceOptions::Auto,
64 )),
65 "required" => Self(ChatCompletionToolChoiceOption::Mode(
66 ToolChoiceOptions::Required,
67 )),
68 "none" => Self(ChatCompletionToolChoiceOption::Mode(
69 ToolChoiceOptions::None,
70 )),
71 _ => Self(ChatCompletionToolChoiceOption::Custom(
72 ChatCompletionNamedToolChoiceCustom {
73 custom: CustomName {
74 name: s.to_string(),
75 },
76 },
77 )),
78 })
79 }
80}
81
// Smart-pointer style access: lets callers use `LLMToolChoice` wherever a
// borrowed `ChatCompletionToolChoiceOption` is expected.
impl Deref for LLMToolChoice {
    type Target = ChatCompletionToolChoiceOption;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
88
// Mutable counterpart of the `Deref` impl above.
impl DerefMut for LLMToolChoice {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}
94
95impl From<ChatCompletionToolChoiceOption> for LLMToolChoice {
96 fn from(value: ChatCompletionToolChoiceOption) -> Self {
97 Self(value)
98 }
99}
100
101impl From<LLMToolChoice> for ChatCompletionToolChoiceOption {
102 fn from(value: LLMToolChoice) -> Self {
103 value.0
104 }
105}
106
/// Tunable generation parameters shared by every prompt, populated from CLI
/// flags or environment variables via clap.
#[derive(Args, Clone, Debug)]
pub struct LLMSettings {
    /// Sampling temperature forwarded to the completion request.
    #[arg(long, env = "LLM_TEMPERATURE", default_value_t = 0.8)]
    pub llm_temperature: f32,

    /// Presence penalty forwarded to the completion request.
    #[arg(long, env = "LLM_PRESENCE_PENALTY", default_value_t = 0.0)]
    pub llm_presence_penalty: f32,

    /// Per-attempt timeout in seconds; 0 means wait forever
    /// (see `prompt_once_with_retry`).
    #[arg(long, env = "LLM_PROMPT_TIMEOUT", default_value_t = 120)]
    pub llm_prompt_timeout: u64,

    /// Maximum number of attempts for a single prompt.
    #[arg(long, env = "LLM_RETRY", default_value_t = 5)]
    pub llm_retry: u64,

    /// Upper bound on completion tokens per request.
    #[arg(long, env = "LLM_MAX_COMPLETION_TOKENS", default_value_t = 16384)]
    pub llm_max_completion_tokens: u32,

    /// Optional tool-choice override: "auto" / "required" / "none", or the
    /// name of a custom tool to force (see `LLMToolChoice::from_str`).
    // NOTE(review): env name "LLM_TOOL_CHOINCE" looks like a typo for
    // "LLM_TOOL_CHOICE" — renaming would break existing deployments that set
    // the misspelled variable, so it is only flagged here; confirm intent.
    #[arg(long, env = "LLM_TOOL_CHOINCE")]
    pub llm_tool_choice: Option<LLMToolChoice>,

    /// When true, use the streaming API and reassemble chunks locally
    /// (see `complete_streaming`).
    #[arg(
        long,
        env = "LLM_STREAM",
        default_value_t = false,
        value_parser = clap::builder::BoolishValueParser::new()
    )]
    pub llm_stream: bool,
}
135
/// CLI/env configuration for the OpenAI (or Azure OpenAI) backend, plus the
/// default generation settings.
#[derive(Args, Clone, Debug)]
pub struct OpenAISetup {
    /// Base URL for the plain OpenAI-compatible API.
    #[arg(
        long,
        env = "OPENAI_API_URL",
        default_value = "https://api.openai.com/v1"
    )]
    pub openai_url: String,

    /// When set, the Azure backend is selected instead of the plain API
    /// (see `to_config`).
    #[arg(long, env = "AZURE_OPENAI_ENDPOINT")]
    pub azure_openai_endpoint: Option<String>,

    /// API key for either backend; defaults to the empty string when unset.
    #[arg(long, env = "OPENAI_API_KEY")]
    pub openai_key: Option<String>,

    /// Azure deployment id; falls back to the model name when absent.
    #[arg(long, env = "AZURE_API_DEPLOYMENT")]
    pub azure_deployment: Option<String>,

    #[arg(long, env = "AZURE_API_VERSION", default_value = "2025-01-01-preview")]
    pub azure_api_version: String,

    /// Spending cap handed to `ModelBilling::new`.
    // NOTE(review): field name is misspelled ("biling" vs "billing");
    // renaming would break callers and the derived `--biling-cap` flag, so it
    // is only flagged here rather than changed.
    #[arg(long, default_value_t = 10.0, env = "OPENAI_BILLING_CAP")]
    pub biling_cap: f64,

    /// Model used for all requests.
    #[arg(long, env = "OPENAI_API_MODEL", default_value = "o1")]
    pub model: OpenAIModel,

    /// Base directory under which per-process debug dump folders are created;
    /// `None` disables transcript dumping.
    #[arg(long, env = "LLM_DEBUG")]
    pub llm_debug: Option<PathBuf>,

    #[clap(flatten)]
    pub llm_settings: LLMSettings,
}
169
170impl OpenAISetup {
171 pub fn to_config(&self) -> SupportedConfig {
172 if let Some(ep) = self.azure_openai_endpoint.as_ref() {
173 let cfg = AzureConfig::new()
174 .with_api_base(ep)
175 .with_api_key(self.openai_key.clone().unwrap_or_default())
176 .with_deployment_id(
177 self.azure_deployment
178 .as_ref()
179 .unwrap_or(&self.model.to_string()),
180 )
181 .with_api_version(&self.azure_api_version);
182 SupportedConfig::Azure(cfg)
183 } else {
184 let cfg = OpenAIConfig::new()
185 .with_api_base(&self.openai_url)
186 .with_api_key(self.openai_key.clone().unwrap_or_default());
187 SupportedConfig::OpenAI(cfg)
188 }
189 }
190
191 pub fn to_llm(&self) -> LLM {
192 let billing = RwLock::new(ModelBilling::new(self.biling_cap));
193
194 let debug_path = if let Some(dbg) = self.llm_debug.as_ref() {
195 let pid = std::process::id();
196
197 let mut cnt = 0u64;
198 let debug_path;
199 loop {
200 let test_path = dbg.join(format!("{}-{}", pid, cnt));
201 if !test_path.exists() {
202 std::fs::create_dir_all(&test_path).expect("Fail to create llm debug path?");
203 debug_path = Some(test_path);
204 debug!("The path to save LLM interactions is {:?}", &debug_path);
205 break;
206 } else {
207 cnt += 1;
208 }
209 }
210 debug_path
211 } else {
212 None
213 };
214
215 LLM {
216 llm: Arc::new(LLMInner {
217 client: LLMClient::new(self.to_config()),
218 model: self.model.clone(),
219 billing,
220 llm_debug: debug_path,
221 llm_debug_index: AtomicU64::new(0),
222 default_settings: self.llm_settings.clone(),
223 }),
224 }
225 }
226}
227
/// Which backend the client talks to: Azure OpenAI, or any OpenAI-compatible
/// API endpoint.
#[derive(Debug, Clone)]
pub enum SupportedConfig {
    Azure(AzureConfig),
    OpenAI(OpenAIConfig),
}
233
/// A chat client specialized for one of the supported backend configurations.
#[derive(Debug, Clone)]
pub enum LLMClient {
    Azure(Client<AzureConfig>),
    OpenAI(Client<OpenAIConfig>),
}
239
240impl LLMClient {
241 pub fn new(config: SupportedConfig) -> Self {
242 match config {
243 SupportedConfig::Azure(cfg) => Self::Azure(Client::with_config(cfg)),
244 SupportedConfig::OpenAI(cfg) => Self::OpenAI(Client::with_config(cfg)),
245 }
246 }
247
248 pub async fn create_chat(
249 &self,
250 req: CreateChatCompletionRequest,
251 ) -> Result<CreateChatCompletionResponse, OpenAIError> {
252 match self {
253 Self::Azure(cl) => cl.chat().create(req).await,
254 Self::OpenAI(cl) => cl.chat().create(req).await,
255 }
256 }
257
258 pub async fn create_chat_stream(
259 &self,
260 req: CreateChatCompletionRequest,
261 ) -> Result<ChatCompletionResponseStream, OpenAIError> {
262 match self {
263 Self::Azure(cl) => cl.chat().create_stream(req).await,
264 Self::OpenAI(cl) => cl.chat().create_stream(req).await,
265 }
266 }
267}
268
/// Tracks accumulated model spend against a hard cap. Costs are derived from
/// per-1M-token prices (see `input_tokens` / `output_tokens`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelBilling {
    /// Cost accrued so far.
    pub current: f64,
    /// Hard limit; once `current` exceeds it, token accounting errors out.
    pub cap: f64,
}
274
275impl Display for ModelBilling {
276 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
277 f.write_fmt(format_args!("Billing({}/{})", self.current, self.cap))
278 }
279}
280
281impl ModelBilling {
282 pub fn new(cap: f64) -> Self {
283 Self { current: 0.0, cap }
284 }
285
286 pub fn in_cap(&self) -> bool {
287 self.current <= self.cap
288 }
289
290 pub fn input_tokens(
291 &mut self,
292 model: &OpenAIModel,
293 count: u64,
294 cached_count: u64,
295 ) -> Result<()> {
296 let pricing = model.pricing();
297
298 let cached_price = if let Some(cached) = pricing.cached_input_tokens {
299 cached
300 } else {
301 pricing.input_tokens
302 };
303
304 self.current += (cached_price * (cached_count as f64)) / 1e6;
305 self.current += (pricing.input_tokens * (count as f64)) / 1e6;
306
307 if self.in_cap() {
308 Ok(())
309 } else {
310 Err(eyre!("cap {} reached, current {}", self.cap, self.current))
311 }
312 }
313
314 pub fn output_tokens(&mut self, model: &OpenAIModel, count: u64) -> Result<()> {
315 let pricing = model.pricing();
316
317 self.current += pricing.output_tokens * (count as f64) / 1e6;
318
319 if self.in_cap() {
320 Ok(())
321 } else {
322 Err(eyre!("cap {} reached, current {}", self.cap, self.current))
323 }
324 }
325}
326
/// Cheaply cloneable handle around the shared [`LLMInner`] state.
#[derive(Debug, Clone)]
pub struct LLM {
    pub llm: Arc<LLMInner>,
}
331
// Forwards method calls on `LLM` straight to the shared inner state.
impl Deref for LLM {
    type Target = LLMInner;

    fn deref(&self) -> &Self::Target {
        &self.llm
    }
}
339
/// Shared state behind [`LLM`]: the backend client, model id, spend tracking,
/// and optional transcript dumping.
#[derive(Debug)]
pub struct LLMInner {
    pub client: LLMClient,
    pub model: OpenAIModel,
    /// Behind a lock so concurrent completions can account spend safely.
    pub billing: RwLock<ModelBilling>,
    /// Directory for request/response transcripts; `None` disables dumping.
    pub llm_debug: Option<PathBuf>,
    /// Monotonic counter used to name transcript files.
    pub llm_debug_index: AtomicU64,
    /// Settings applied when a call site does not supply its own.
    pub default_settings: LLMSettings,
}
349
350pub fn completion_to_role(msg: &ChatCompletionRequestMessage) -> &'static str {
351 match msg {
352 ChatCompletionRequestMessage::Assistant(_) => "ASSISTANT",
353 ChatCompletionRequestMessage::Developer(_) => "DEVELOPER",
354 ChatCompletionRequestMessage::Function(_) => "FUNCTION",
355 ChatCompletionRequestMessage::System(_) => "SYSTEM",
356 ChatCompletionRequestMessage::Tool(_) => "TOOL",
357 ChatCompletionRequestMessage::User(_) => "USER",
358 }
359}
360
361pub fn toolcall_to_string(t: &ChatCompletionMessageToolCalls) -> String {
362 match t {
363 ChatCompletionMessageToolCalls::Function(t) => {
364 format!(
365 "<toolcall name=\"{}\">\n{}\n</toolcall>",
366 &t.function.name, &t.function.arguments
367 )
368 }
369 ChatCompletionMessageToolCalls::Custom(t) => {
370 format!(
371 "<customtoolcall name=\"{}\">\n{}\n</customtoolcall>",
372 &t.custom_tool.name, &t.custom_tool.input
373 )
374 }
375 }
376}
377
378pub fn response_to_string(resp: &ChatCompletionResponseMessage) -> String {
379 let mut s = String::new();
380 if let Some(content) = resp.content.as_ref() {
381 s += content;
382 s += "\n";
383 }
384
385 if let Some(tools) = resp.tool_calls.as_ref() {
386 s += &tools.iter().map(|t| toolcall_to_string(t)).join("\n");
387 }
388
389 if let Some(refusal) = &resp.refusal {
390 s += refusal;
391 s += "\n";
392 }
393
394 let role = resp.role.to_string().to_uppercase();
395
396 format!("<{}>\n{}\n</{}>\n", &role, s, &role)
397}
398
399pub fn completion_to_string(msg: &ChatCompletionRequestMessage) -> String {
400 const CONT: &str = "<cont/>\n";
401 const NONE: &str = "<none/>\n";
402 let role = completion_to_role(msg);
403 let content = match msg {
404 ChatCompletionRequestMessage::Assistant(ass) => {
405 let msg = ass
406 .content
407 .as_ref()
408 .map(|ass| match ass {
409 ChatCompletionRequestAssistantMessageContent::Text(s) => s.clone(),
410 ChatCompletionRequestAssistantMessageContent::Array(arr) => arr
411 .iter()
412 .map(|v| match v {
413 ChatCompletionRequestAssistantMessageContentPart::Text(s) => {
414 s.text.clone()
415 }
416 ChatCompletionRequestAssistantMessageContentPart::Refusal(rf) => {
417 rf.refusal.clone()
418 }
419 })
420 .join(CONT),
421 })
422 .unwrap_or(NONE.to_string());
423 let tool_calls = ass
424 .tool_calls
425 .iter()
426 .flatten()
427 .map(|t| toolcall_to_string(t))
428 .join("\n");
429 format!("{}\n{}", msg, tool_calls)
430 }
431 ChatCompletionRequestMessage::Developer(dev) => match &dev.content {
432 ChatCompletionRequestDeveloperMessageContent::Text(t) => t.clone(),
433 ChatCompletionRequestDeveloperMessageContent::Array(arr) => arr
434 .iter()
435 .map(|v| match v {
436 ChatCompletionRequestDeveloperMessageContentPart::Text(v) => v.text.clone(),
437 })
438 .join(CONT),
439 },
440 ChatCompletionRequestMessage::Function(f) => f.content.clone().unwrap_or(NONE.to_string()),
441 ChatCompletionRequestMessage::System(sys) => match &sys.content {
442 ChatCompletionRequestSystemMessageContent::Text(t) => t.clone(),
443 ChatCompletionRequestSystemMessageContent::Array(arr) => arr
444 .iter()
445 .map(|v| match v {
446 ChatCompletionRequestSystemMessageContentPart::Text(t) => t.text.clone(),
447 })
448 .join(CONT),
449 },
450 ChatCompletionRequestMessage::Tool(tool) => match &tool.content {
451 ChatCompletionRequestToolMessageContent::Text(t) => t.clone(),
452 ChatCompletionRequestToolMessageContent::Array(arr) => arr
453 .iter()
454 .map(|v| match v {
455 ChatCompletionRequestToolMessageContentPart::Text(t) => t.text.clone(),
456 })
457 .join(CONT),
458 },
459 ChatCompletionRequestMessage::User(usr) => match &usr.content {
460 ChatCompletionRequestUserMessageContent::Text(t) => t.clone(),
461 ChatCompletionRequestUserMessageContent::Array(arr) => arr
462 .iter()
463 .map(|v| match v {
464 ChatCompletionRequestUserMessageContentPart::Text(t) => t.text.clone(),
465 ChatCompletionRequestUserMessageContentPart::ImageUrl(img) => {
466 format!("<img url=\"{}\"/>", &img.image_url.url)
467 }
468 ChatCompletionRequestUserMessageContentPart::InputAudio(audio) => {
469 format!("<audio>{}</audio>", audio.input_audio.data)
470 }
471 ChatCompletionRequestUserMessageContentPart::File(f) => {
472 format!("<file>{:?}</file>", f)
473 }
474 })
475 .join(CONT),
476 },
477 };
478
479 format!("<{}>\n{}\n</{}>\n", role, content, role)
480}
481
482impl LLMInner {
483 async fn rewrite_json<T: Serialize + Debug>(fpath: &Path, t: &T) -> Result<(), PromptError> {
484 let mut json_fp = fpath.to_path_buf();
485 json_fp.set_file_name(format!(
486 "{}.json",
487 json_fp
488 .file_stem()
489 .ok_or_eyre(eyre!("no filename"))?
490 .to_str()
491 .ok_or_eyre(eyre!("non-utf fname"))?
492 ));
493
494 let mut fp = tokio::fs::OpenOptions::new()
495 .create(true)
496 .append(true)
497 .write(true)
498 .open(&json_fp)
499 .await?;
500 let s = match serde_json::to_string(&t) {
501 Ok(s) => s,
502 Err(_) => format!("{:?}", &t),
503 };
504 fp.write_all(s.as_bytes()).await?;
505 fp.write_all(b"\n").await?;
506 fp.flush().await?;
507
508 Ok(())
509 }
510
511 async fn save_llm_user(
512 fpath: &PathBuf,
513 user_msg: &CreateChatCompletionRequest,
514 ) -> Result<(), PromptError> {
515 let mut fp = tokio::fs::OpenOptions::new()
516 .create(true)
517 .truncate(true)
518 .write(true)
519 .open(&fpath)
520 .await?;
521 fp.write_all(b"=====================\n<Request>\n").await?;
522 for it in user_msg.messages.iter() {
523 let msg = completion_to_string(it);
524 fp.write_all(msg.as_bytes()).await?;
525 }
526
527 let mut tools = vec![];
528 for tool in user_msg
529 .tools
530 .as_ref()
531 .map(|t| t.iter())
532 .into_iter()
533 .flatten()
534 {
535 let s = match tool {
536 ChatCompletionTools::Function(tool) => {
537 format!(
538 "<tool name=\"{}\", description=\"{}\", strict={}>\n{}\n</tool>",
539 &tool.function.name,
540 &tool.function.description.clone().unwrap_or_default(),
541 tool.function.strict.unwrap_or_default(),
542 tool.function
543 .parameters
544 .as_ref()
545 .map(serde_json::to_string_pretty)
546 .transpose()?
547 .unwrap_or_default()
548 )
549 }
550 ChatCompletionTools::Custom(tool) => {
551 format!(
552 "<customtool name=\"{}\", description=\"{:?}\"></customtool>",
553 tool.custom.name, tool.custom.description
554 )
555 }
556 };
557 tools.push(s);
558 }
559 fp.write_all(tools.join("\n").as_bytes()).await?;
560 fp.write_all(b"\n</Request>\n=====================\n")
561 .await?;
562 fp.flush().await?;
563
564 Self::rewrite_json(fpath, user_msg).await?;
565
566 Ok(())
567 }
568
569 async fn save_llm_resp(fpath: &PathBuf, resp: &CreateChatCompletionResponse) -> Result<()> {
570 let mut fp = tokio::fs::OpenOptions::new()
571 .create(false)
572 .append(true)
573 .write(true)
574 .open(&fpath)
575 .await?;
576 fp.write_all(b"=====================\n<Response>\n").await?;
577 for it in &resp.choices {
578 let msg = response_to_string(&it.message);
579 fp.write_all(msg.as_bytes()).await?;
580 }
581 fp.write_all(b"\n</Response>\n=====================\n")
582 .await?;
583 fp.flush().await?;
584
585 Self::rewrite_json(fpath, resp).await?;
586
587 Ok(())
588 }
589
590 fn on_llm_debug(&self, prefix: &str) -> Option<PathBuf> {
591 if let Some(output_folder) = self.llm_debug.as_ref() {
592 let idx = self.llm_debug_index.fetch_add(1, Ordering::SeqCst);
593 let fpath = output_folder.join(format!("{}-{:0>12}.xml", prefix, idx));
594 Some(fpath)
595 } else {
596 None
597 }
598 }
599
600 pub async fn prompt_once_with_retry(
602 &self,
603 sys_msg: &str,
604 user_msg: &str,
605 prefix: Option<&str>,
606 settings: Option<LLMSettings>,
607 ) -> Result<CreateChatCompletionResponse, PromptError> {
608 let settings = settings.unwrap_or_else(|| self.default_settings.clone());
609 let sys = ChatCompletionRequestSystemMessageArgs::default()
610 .content(sys_msg)
611 .build()?;
612
613 let user = ChatCompletionRequestUserMessageArgs::default()
614 .content(user_msg)
615 .build()?;
616 let mut req = CreateChatCompletionRequestArgs::default();
617 req.messages(vec![sys.into(), user.into()])
618 .model(self.model.to_string())
619 .temperature(settings.llm_temperature)
620 .presence_penalty(settings.llm_presence_penalty)
621 .max_completion_tokens(settings.llm_max_completion_tokens);
622
623 if let Some(tc) = settings.llm_tool_choice {
624 req.tool_choice(tc);
625 }
626 let req = req.build()?;
627
628 let timeout = if settings.llm_prompt_timeout == 0 {
629 Duration::MAX
630 } else {
631 Duration::from_secs(settings.llm_prompt_timeout)
632 };
633
634 self.complete_once_with_retry(&req, prefix, Some(timeout), Some(settings.llm_retry))
635 .await
636 }
637
638 pub async fn complete_once_with_retry(
639 &self,
640 req: &CreateChatCompletionRequest,
641 prefix: Option<&str>,
642 timeout: Option<Duration>,
643 retry: Option<u64>,
644 ) -> Result<CreateChatCompletionResponse, PromptError> {
645 let timeout = if let Some(timeout) = timeout {
646 timeout
647 } else {
648 Duration::MAX
649 };
650
651 let retry = if let Some(retry) = retry {
652 retry
653 } else {
654 u64::MAX
655 };
656
657 let mut last = None;
658 for idx in 0..retry {
659 match tokio::time::timeout(timeout, self.complete(req.clone(), prefix)).await {
660 Ok(r) => {
661 last = Some(r);
662 }
663 Err(_) => {
664 warn!("Timeout with {} retry, timeout = {:?}", idx, timeout);
665 continue;
666 }
667 };
668
669 match last {
670 Some(Ok(r)) => return Ok(r),
671 Some(Err(ref e)) => {
672 warn!(
673 "Having an error {} during {} retry (timeout is {:?})",
674 e, idx, timeout
675 );
676 }
677 _ => {}
678 }
679 }
680
681 last.ok_or_eyre(eyre!("retry is zero?!"))
682 .map_err(PromptError::Other)?
683 }
684
685 pub async fn complete(
686 &self,
687 req: CreateChatCompletionRequest,
688 prefix: Option<&str>,
689 ) -> Result<CreateChatCompletionResponse, PromptError> {
690 let use_stream = self.default_settings.llm_stream;
691 let prefix = if let Some(prefix) = prefix {
692 prefix.to_string()
693 } else {
694 "llm".to_string()
695 };
696 let debug_fp = self.on_llm_debug(&prefix);
697
698 if let Some(debug_fp) = debug_fp.as_ref() {
699 if let Err(e) = Self::save_llm_user(debug_fp, &req).await {
700 warn!("Fail to save user due to {}", e);
701 }
702 }
703
704 trace!(
705 "Sending completion request: {:?}",
706 &serde_json::to_string(&req)
707 );
708 let resp = if use_stream {
709 self.complete_streaming(req).await?
710 } else {
711 self.client.create_chat(req).await?
712 };
713
714 if let Some(debug_fp) = debug_fp.as_ref() {
715 if let Err(e) = Self::save_llm_resp(debug_fp, &resp).await {
716 warn!("Fail to save resp due to {}", e);
717 }
718 }
719
720 if let Some(usage) = &resp.usage {
721 let cached = usage
722 .prompt_tokens_details
723 .as_ref()
724 .map(|v| v.cached_tokens)
725 .flatten()
726 .unwrap_or_default();
727 let input = usage.prompt_tokens - cached;
728 self.billing
729 .write()
730 .await
731 .input_tokens(&self.model, input as _, cached as _)
732 .map_err(PromptError::Other)?;
733 self.billing
734 .write()
735 .await
736 .output_tokens(&self.model, usage.completion_tokens as u64)
737 .map_err(PromptError::Other)?;
738 } else {
739 warn!("No usage?!")
740 }
741
742 info!("Model Billing: {}", &self.billing.read().await);
743 Ok(resp)
744 }
745
746 async fn complete_streaming(
747 &self,
748 mut req: CreateChatCompletionRequest,
749 ) -> Result<CreateChatCompletionResponse, PromptError> {
750 if req.stream_options.is_none() {
751 req.stream_options = Some(ChatCompletionStreamOptions {
752 include_usage: Some(true),
753 include_obfuscation: None,
754 });
755 }
756
757 let mut stream = self.client.create_chat_stream(req).await?;
758
759 let mut id: Option<String> = None;
760 let mut created: Option<u32> = None;
761 let mut model: Option<String> = None;
762 let mut service_tier = None;
763 let mut system_fingerprint = None;
764 let mut usage: Option<CompletionUsage> = None;
765
766 let mut contents: Vec<String> = Vec::new();
767 let mut finish_reasons: Vec<Option<FinishReason>> = Vec::new();
768 let mut tool_calls: Vec<Vec<ToolCallAcc>> = Vec::new();
769
770 while let Some(item) = stream.next().await {
771 let chunk: CreateChatCompletionStreamResponse = item?;
772 if id.is_none() {
773 id = Some(chunk.id.clone());
774 }
775 created = Some(chunk.created);
776 model = Some(chunk.model.clone());
777 service_tier = chunk.service_tier.clone();
778 system_fingerprint = chunk.system_fingerprint.clone();
779 if let Some(u) = chunk.usage.clone() {
780 usage = Some(u);
781 }
782
783 for ch in chunk.choices.into_iter() {
784 let idx = ch.index as usize;
785 if contents.len() <= idx {
786 contents.resize_with(idx + 1, String::new);
787 finish_reasons.resize_with(idx + 1, || None);
788 tool_calls.resize_with(idx + 1, Vec::new);
789 }
790 if let Some(delta) = ch.delta.content {
791 contents[idx].push_str(&delta);
792 }
793 if let Some(tcs) = ch.delta.tool_calls {
794 for tc in tcs.into_iter() {
795 let tc_idx = tc.index as usize;
796 if tool_calls[idx].len() <= tc_idx {
797 tool_calls[idx].resize_with(tc_idx + 1, ToolCallAcc::default);
798 }
799 let acc = &mut tool_calls[idx][tc_idx];
800 if let Some(id) = tc.id {
801 acc.id = id;
802 }
803 if let Some(func) = tc.function {
804 if let Some(name) = func.name {
805 acc.name = name;
806 }
807 if let Some(args) = func.arguments {
808 acc.arguments.push_str(&args);
809 }
810 }
811 }
812 }
813 if ch.finish_reason.is_some() {
814 finish_reasons[idx] = ch.finish_reason;
815 }
816 }
817 }
818
819 let mut choices = Vec::new();
820 for (idx, content) in contents.into_iter().enumerate() {
821 let finish_reason = finish_reasons.get(idx).cloned().unwrap_or(None);
822 let built_tool_calls = tool_calls
823 .get(idx)
824 .cloned()
825 .unwrap_or_default()
826 .into_iter()
827 .filter(|t| !t.name.trim().is_empty() || !t.arguments.trim().is_empty())
828 .map(|t| {
829 ChatCompletionMessageToolCalls::Function(ChatCompletionMessageToolCall {
830 id: if t.id.trim().is_empty() {
831 format!("toolcall-{}", idx)
832 } else {
833 t.id
834 },
835 function: FunctionCall {
836 name: t.name,
837 arguments: t.arguments,
838 },
839 })
840 })
841 .collect::<Vec<_>>();
842 let tool_calls_opt = if built_tool_calls.is_empty() {
843 None
844 } else {
845 Some(built_tool_calls)
846 };
847 choices.push(ChatChoice {
848 index: idx as u32,
849 message: ChatCompletionResponseMessage {
850 content: if content.is_empty() {
851 None
852 } else {
853 Some(content)
854 },
855 refusal: None,
856 tool_calls: tool_calls_opt,
857 annotations: None,
858 role: Role::Assistant,
859 function_call: None,
860 audio: None,
861 },
862 finish_reason,
863 logprobs: None,
864 });
865 }
866 if choices.is_empty() {
867 choices.push(ChatChoice {
868 index: 0,
869 message: ChatCompletionResponseMessage {
870 content: Some(String::new()),
871 refusal: None,
872 tool_calls: None,
873 annotations: None,
874 role: Role::Assistant,
875 function_call: None,
876 audio: None,
877 },
878 finish_reason: None,
879 logprobs: None,
880 });
881 }
882
883 Ok(CreateChatCompletionResponse {
884 id: id.unwrap_or_else(|| "stream".to_string()),
885 choices,
886 created: created.unwrap_or(0),
887 model: model.unwrap_or_else(|| self.model.to_string()),
888 service_tier,
889 system_fingerprint,
890 object: "chat.completion".to_string(),
891 usage,
892 })
893 }
894
895 pub async fn prompt_once(
896 &self,
897 sys_msg: &str,
898 user_msg: &str,
899 prefix: Option<&str>,
900 settings: Option<LLMSettings>,
901 ) -> Result<CreateChatCompletionResponse, PromptError> {
902 let settings = settings.unwrap_or_else(|| self.default_settings.clone());
903 let sys = ChatCompletionRequestSystemMessageArgs::default()
904 .content(sys_msg)
905 .build()?;
906
907 let user = ChatCompletionRequestUserMessageArgs::default()
908 .content(user_msg)
909 .build()?;
910 let req = CreateChatCompletionRequestArgs::default()
911 .messages(vec![sys.into(), user.into()])
912 .model(self.model.to_string())
913 .temperature(settings.llm_temperature)
914 .presence_penalty(settings.llm_presence_penalty)
915 .max_completion_tokens(settings.llm_max_completion_tokens)
916 .build()?;
917 self.complete(req, prefix).await
918 }
919}