1use std::{
2 fmt::{Debug, Display},
3 ops::{Deref, DerefMut},
4 path::{Path, PathBuf},
5 str::FromStr,
6 sync::{
7 Arc,
8 atomic::{AtomicU64, Ordering},
9 },
10 time::Duration,
11};
12
13use async_openai::{
14 Client,
15 config::{AzureConfig, OpenAIConfig},
16 error::OpenAIError,
17 types::chat::{
18 ChatChoice, ChatCompletionMessageToolCall, ChatCompletionMessageToolCalls,
19 ChatCompletionNamedToolChoiceCustom, ChatCompletionRequestAssistantMessageContent,
20 ChatCompletionRequestAssistantMessageContentPart,
21 ChatCompletionRequestDeveloperMessageContent,
22 ChatCompletionRequestDeveloperMessageContentPart, ChatCompletionRequestMessage,
23 ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestSystemMessageContent,
24 ChatCompletionRequestSystemMessageContentPart, ChatCompletionRequestToolMessageContent,
25 ChatCompletionRequestToolMessageContentPart, ChatCompletionRequestUserMessageArgs,
26 ChatCompletionRequestUserMessageContent, ChatCompletionRequestUserMessageContentPart,
27 ChatCompletionResponseMessage, ChatCompletionResponseStream, ChatCompletionStreamOptions,
28 ChatCompletionToolChoiceOption, ChatCompletionTools, CompletionUsage,
29 CreateChatCompletionRequest, CreateChatCompletionRequestArgs, CreateChatCompletionResponse,
30 CreateChatCompletionStreamResponse, CustomName, FinishReason, FunctionCall, Role,
31 ToolChoiceOptions,
32 },
33};
34use clap::Args;
35use color_eyre::{
36 Result,
37 eyre::{OptionExt, eyre},
38};
39use futures_util::StreamExt;
40use itertools::Itertools;
41use log::{debug, info, trace, warn};
42use serde::{Deserialize, Serialize};
43use tokio::{io::AsyncWriteExt, sync::RwLock};
44
45use crate::{OpenAIModel, error::PromptError};
46
/// Accumulator for one tool call arriving in streaming fragments:
/// `id` and `name` are set when their delta arrives, while `arguments`
/// is concatenated across all deltas for the same tool-call index.
#[derive(Clone, Debug, Default)]
struct ToolCallAcc {
    id: String,
    name: String,
    arguments: String,
}
53
/// Newtype over [`ChatCompletionToolChoiceOption`] so a tool choice can be
/// parsed from a CLI flag or environment variable (see the `FromStr` impl).
#[derive(Debug, Clone)]
pub struct LLMToolChoice(pub ChatCompletionToolChoiceOption);
57
58impl FromStr for LLMToolChoice {
59 type Err = PromptError;
60 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
61 Ok(match s {
62 "auto" => Self(ChatCompletionToolChoiceOption::Mode(
63 ToolChoiceOptions::Auto,
64 )),
65 "required" => Self(ChatCompletionToolChoiceOption::Mode(
66 ToolChoiceOptions::Required,
67 )),
68 "none" => Self(ChatCompletionToolChoiceOption::Mode(
69 ToolChoiceOptions::None,
70 )),
71 _ => Self(ChatCompletionToolChoiceOption::Custom(
72 ChatCompletionNamedToolChoiceCustom {
73 custom: CustomName {
74 name: s.to_string(),
75 },
76 },
77 )),
78 })
79 }
80}
81
// Transparent read access to the wrapped option.
impl Deref for LLMToolChoice {
    type Target = ChatCompletionToolChoiceOption;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
88
// Transparent mutable access to the wrapped option.
impl DerefMut for LLMToolChoice {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}
94
// Wrap a raw option into the newtype.
impl From<ChatCompletionToolChoiceOption> for LLMToolChoice {
    fn from(value: ChatCompletionToolChoiceOption) -> Self {
        Self(value)
    }
}
100
// Unwrap the newtype back into the raw option (used by the request builder).
impl From<LLMToolChoice> for ChatCompletionToolChoiceOption {
    fn from(value: LLMToolChoice) -> Self {
        value.0
    }
}
106
107#[derive(Args, Clone, Debug)]
108pub struct LLMSettings {
109 #[arg(long, env = "LLM_TEMPERATURE", default_value_t = 0.8)]
110 pub llm_temperature: f32,
111
112 #[arg(long, env = "LLM_PRESENCE_PENALTY", default_value_t = 0.0)]
113 pub llm_presence_penalty: f32,
114
115 #[arg(long, env = "LLM_PROMPT_TIMEOUT", default_value_t = 120)]
116 pub llm_prompt_timeout: u64,
117
118 #[arg(long, env = "LLM_RETRY", default_value_t = 5)]
119 pub llm_retry: u64,
120
121 #[arg(long, env = "LLM_MAX_COMPLETION_TOKENS", default_value_t = 16384)]
122 pub llm_max_completion_tokens: u32,
123
124 #[arg(long, env = "LLM_TOOL_CHOINCE")]
125 pub llm_tool_choice: Option<LLMToolChoice>,
126
127 #[arg(
128 long,
129 env = "LLM_STREAM",
130 default_value_t = false,
131 value_parser = clap::builder::BoolishValueParser::new()
132 )]
133 pub llm_stream: bool,
134}
135
// Backend connection options (plain OpenAI or Azure) plus shared defaults.
// (Plain `//` comments on purpose: `///` would become clap help text.)
#[derive(Args, Clone, Debug)]
pub struct OpenAISetup {
    // Base URL for the plain-OpenAI backend; ignored when an Azure endpoint
    // is provided (see `to_config`).
    #[arg(
        long,
        env = "OPENAI_API_URL",
        default_value = "https://api.openai.com/v1"
    )]
    pub openai_url: String,

    // When set, the Azure backend is used instead of `openai_url`.
    #[arg(long, env = "AZURE_OPENAI_ENDPOINT")]
    pub azure_openai_endpoint: Option<String>,

    // API key for either backend; empty string when absent.
    #[arg(long, env = "OPENAI_API_KEY")]
    pub openai_key: Option<String>,

    // Azure deployment id; falls back to the model name when unset.
    #[arg(long, env = "AZURE_API_DEPLOYMENT")]
    pub azure_deployment: Option<String>,

    #[arg(long, env = "AZURE_API_VERSION", default_value = "2025-01-01-preview")]
    pub azure_api_version: String,

    // Spending cap in USD enforced by ModelBilling.
    // NOTE(review): field name is misspelled ("biling" -> "billing"), but
    // renaming would change the generated `--biling-cap` CLI flag and break
    // existing invocations — leaving as-is.
    #[arg(long, default_value_t = 10.0, env = "OPENAI_BILLING_CAP")]
    pub biling_cap: f64,

    #[arg(long, env = "OPENAI_API_MODEL", default_value = "o1")]
    pub model: OpenAIModel,

    // Directory under which per-process debug transcripts are created.
    #[arg(long, env = "LLM_DEBUG")]
    pub llm_debug: Option<PathBuf>,

    #[clap(flatten)]
    pub llm_settings: LLMSettings,
}
169
170impl OpenAISetup {
171 pub fn to_config(&self) -> SupportedConfig {
172 if let Some(ep) = self.azure_openai_endpoint.as_ref() {
173 let cfg = AzureConfig::new()
174 .with_api_base(ep)
175 .with_api_key(self.openai_key.clone().unwrap_or_default())
176 .with_deployment_id(
177 self.azure_deployment
178 .as_ref()
179 .unwrap_or(&self.model.to_string()),
180 )
181 .with_api_version(&self.azure_api_version);
182 SupportedConfig::Azure(cfg)
183 } else {
184 let cfg = OpenAIConfig::new()
185 .with_api_base(&self.openai_url)
186 .with_api_key(self.openai_key.clone().unwrap_or_default());
187 SupportedConfig::OpenAI(cfg)
188 }
189 }
190
191 pub fn to_llm(&self) -> LLM {
192 let billing = RwLock::new(ModelBilling::new(self.biling_cap));
193
194 let debug_path = if let Some(dbg) = self.llm_debug.as_ref() {
195 let pid = std::process::id();
196
197 let mut cnt = 0u64;
198 let debug_path;
199 loop {
200 let test_path = dbg.join(format!("{}-{}", pid, cnt));
201 if !test_path.exists() {
202 std::fs::create_dir_all(&test_path).expect("Fail to create llm debug path?");
203 debug_path = Some(test_path);
204 debug!("The path to save LLM interactions is {:?}", &debug_path);
205 break;
206 } else {
207 cnt += 1;
208 }
209 }
210 debug_path
211 } else {
212 None
213 };
214
215 LLM {
216 llm: Arc::new(LLMInner {
217 client: LLMClient::new(self.to_config()),
218 model: self.model.clone(),
219 billing,
220 llm_debug: debug_path,
221 llm_debug_index: AtomicU64::new(0),
222 default_settings: self.llm_settings.clone(),
223 }),
224 }
225 }
226}
227
/// Resolved backend configuration: exactly one of the two supported APIs.
#[derive(Debug, Clone)]
pub enum SupportedConfig {
    Azure(AzureConfig),
    OpenAI(OpenAIConfig),
}
233
/// A concrete `async_openai` client for whichever backend was configured.
#[derive(Debug, Clone)]
pub enum LLMClient {
    Azure(Client<AzureConfig>),
    OpenAI(Client<OpenAIConfig>),
}
239
240impl LLMClient {
241 pub fn new(config: SupportedConfig) -> Self {
242 match config {
243 SupportedConfig::Azure(cfg) => Self::Azure(Client::with_config(cfg)),
244 SupportedConfig::OpenAI(cfg) => Self::OpenAI(Client::with_config(cfg)),
245 }
246 }
247
248 pub async fn create_chat(
249 &self,
250 req: CreateChatCompletionRequest,
251 ) -> Result<CreateChatCompletionResponse, OpenAIError> {
252 match self {
253 Self::Azure(cl) => cl.chat().create(req).await,
254 Self::OpenAI(cl) => cl.chat().create(req).await,
255 }
256 }
257
258 pub async fn create_chat_stream(
259 &self,
260 req: CreateChatCompletionRequest,
261 ) -> Result<ChatCompletionResponseStream, OpenAIError> {
262 match self {
263 Self::Azure(cl) => cl.chat().create_stream(req).await,
264 Self::OpenAI(cl) => cl.chat().create_stream(req).await,
265 }
266 }
267}
268
/// Running spend tracker: accumulated USD (`current`) against a hard `cap`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelBilling {
    pub current: f64,
    pub cap: f64,
}
274
275impl Display for ModelBilling {
276 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
277 f.write_fmt(format_args!("Billing({}/{})", self.current, self.cap))
278 }
279}
280
281impl ModelBilling {
282 pub fn new(cap: f64) -> Self {
283 Self { current: 0.0, cap }
284 }
285
286 pub fn in_cap(&self) -> bool {
287 self.current <= self.cap
288 }
289
290 pub fn input_tokens(
291 &mut self,
292 model: &OpenAIModel,
293 input_count: u64,
294 cached_count: u64,
295 ) -> Result<()> {
296 let pricing = model.pricing();
297
298 let cached_price = if let Some(cached) = pricing.cached_input_tokens {
299 cached
300 } else {
301 pricing.input_tokens
302 };
303
304 let cached_usd = (cached_price * (cached_count as f64)) / 1e6;
305 let raw_input_usd = (pricing.input_tokens * (input_count as f64)) / 1e6;
306
307 log::debug!(
308 "Input token usage: cached {:.2} USD, {} tokens / input: {:.2} USD, {} tokens",
309 cached_usd,
310 cached_count,
311 raw_input_usd,
312 input_count
313 );
314 self.current += cached_usd + raw_input_usd;
315
316 if self.in_cap() {
317 Ok(())
318 } else {
319 Err(eyre!("cap {} reached, current {}", self.cap, self.current))
320 }
321 }
322
323 pub fn output_tokens(&mut self, model: &OpenAIModel, count: u64) -> Result<()> {
324 let pricing = model.pricing();
325
326 let output_usd = pricing.output_tokens * (count as f64) / 1e6;
327 log::debug!("Output token usage: {} USD, {} tokens", output_usd, count);
328 self.current += output_usd;
329
330 if self.in_cap() {
331 Ok(())
332 } else {
333 Err(eyre!("cap {} reached, current {}", self.cap, self.current))
334 }
335 }
336}
337
/// Cheaply-cloneable handle to the shared LLM state (see [`LLMInner`]).
#[derive(Debug, Clone)]
pub struct LLM {
    pub llm: Arc<LLMInner>,
}
342
// Lets callers invoke LLMInner methods directly on an LLM handle.
impl Deref for LLM {
    type Target = LLMInner;

    fn deref(&self) -> &Self::Target {
        &self.llm
    }
}
350
/// Shared state behind [`LLM`]: the backend client, billing tracker, and
/// optional debug-transcript directory with its monotonically increasing
/// file index.
#[derive(Debug)]
pub struct LLMInner {
    pub client: LLMClient,
    pub model: OpenAIModel,
    pub billing: RwLock<ModelBilling>,
    // Directory for transcripts; None disables debug dumps.
    pub llm_debug: Option<PathBuf>,
    // Next transcript file index (see `on_llm_debug`).
    pub llm_debug_index: AtomicU64,
    pub default_settings: LLMSettings,
}
360
361pub fn completion_to_role(msg: &ChatCompletionRequestMessage) -> &'static str {
362 match msg {
363 ChatCompletionRequestMessage::Assistant(_) => "ASSISTANT",
364 ChatCompletionRequestMessage::Developer(_) => "DEVELOPER",
365 ChatCompletionRequestMessage::Function(_) => "FUNCTION",
366 ChatCompletionRequestMessage::System(_) => "SYSTEM",
367 ChatCompletionRequestMessage::Tool(_) => "TOOL",
368 ChatCompletionRequestMessage::User(_) => "USER",
369 }
370}
371
372pub fn toolcall_to_string(t: &ChatCompletionMessageToolCalls) -> String {
373 match t {
374 ChatCompletionMessageToolCalls::Function(t) => {
375 format!(
376 "<toolcall name=\"{}\">\n{}\n</toolcall>",
377 &t.function.name, &t.function.arguments
378 )
379 }
380 ChatCompletionMessageToolCalls::Custom(t) => {
381 format!(
382 "<customtoolcall name=\"{}\">\n{}\n</customtoolcall>",
383 &t.custom_tool.name, &t.custom_tool.input
384 )
385 }
386 }
387}
388
389pub fn response_to_string(resp: &ChatCompletionResponseMessage) -> String {
390 let mut s = String::new();
391 if let Some(content) = resp.content.as_ref() {
392 s += content;
393 s += "\n";
394 }
395
396 if let Some(tools) = resp.tool_calls.as_ref() {
397 s += &tools.iter().map(|t| toolcall_to_string(t)).join("\n");
398 }
399
400 if let Some(refusal) = &resp.refusal {
401 s += refusal;
402 s += "\n";
403 }
404
405 let role = resp.role.to_string().to_uppercase();
406
407 format!("<{}>\n{}\n</{}>\n", &role, s, &role)
408}
409
/// Render any request message as `<ROLE>…</ROLE>` for the debug transcript.
///
/// Multi-part content is joined with `<cont/>` markers; absent content is
/// rendered as `<none/>`. Non-text user parts (image/audio/file) get
/// placeholder tags rather than raw payloads.
pub fn completion_to_string(msg: &ChatCompletionRequestMessage) -> String {
    const CONT: &str = "<cont/>\n";
    const NONE: &str = "<none/>\n";
    let role = completion_to_role(msg);
    let content = match msg {
        // Assistant messages additionally carry tool calls, appended after
        // the (possibly absent) text content.
        ChatCompletionRequestMessage::Assistant(ass) => {
            let msg = ass
                .content
                .as_ref()
                .map(|ass| match ass {
                    ChatCompletionRequestAssistantMessageContent::Text(s) => s.clone(),
                    ChatCompletionRequestAssistantMessageContent::Array(arr) => arr
                        .iter()
                        .map(|v| match v {
                            ChatCompletionRequestAssistantMessageContentPart::Text(s) => {
                                s.text.clone()
                            }
                            ChatCompletionRequestAssistantMessageContentPart::Refusal(rf) => {
                                rf.refusal.clone()
                            }
                        })
                        .join(CONT),
                })
                .unwrap_or(NONE.to_string());
            // Option<Vec<_>>: iter().flatten() yields nothing when None.
            let tool_calls = ass
                .tool_calls
                .iter()
                .flatten()
                .map(|t| toolcall_to_string(t))
                .join("\n");
            format!("{}\n{}", msg, tool_calls)
        }
        ChatCompletionRequestMessage::Developer(dev) => match &dev.content {
            ChatCompletionRequestDeveloperMessageContent::Text(t) => t.clone(),
            ChatCompletionRequestDeveloperMessageContent::Array(arr) => arr
                .iter()
                .map(|v| match v {
                    ChatCompletionRequestDeveloperMessageContentPart::Text(v) => v.text.clone(),
                })
                .join(CONT),
        },
        // Deprecated function-role messages may carry no content at all.
        ChatCompletionRequestMessage::Function(f) => f.content.clone().unwrap_or(NONE.to_string()),
        ChatCompletionRequestMessage::System(sys) => match &sys.content {
            ChatCompletionRequestSystemMessageContent::Text(t) => t.clone(),
            ChatCompletionRequestSystemMessageContent::Array(arr) => arr
                .iter()
                .map(|v| match v {
                    ChatCompletionRequestSystemMessageContentPart::Text(t) => t.text.clone(),
                })
                .join(CONT),
        },
        ChatCompletionRequestMessage::Tool(tool) => match &tool.content {
            ChatCompletionRequestToolMessageContent::Text(t) => t.clone(),
            ChatCompletionRequestToolMessageContent::Array(arr) => arr
                .iter()
                .map(|v| match v {
                    ChatCompletionRequestToolMessageContentPart::Text(t) => t.text.clone(),
                })
                .join(CONT),
        },
        ChatCompletionRequestMessage::User(usr) => match &usr.content {
            ChatCompletionRequestUserMessageContent::Text(t) => t.clone(),
            ChatCompletionRequestUserMessageContent::Array(arr) => arr
                .iter()
                .map(|v| match v {
                    ChatCompletionRequestUserMessageContentPart::Text(t) => t.text.clone(),
                    ChatCompletionRequestUserMessageContentPart::ImageUrl(img) => {
                        format!("<img url=\"{}\"/>", &img.image_url.url)
                    }
                    ChatCompletionRequestUserMessageContentPart::InputAudio(audio) => {
                        format!("<audio>{}</audio>", audio.input_audio.data)
                    }
                    ChatCompletionRequestUserMessageContentPart::File(f) => {
                        format!("<file>{:?}</file>", f)
                    }
                })
                .join(CONT),
        },
    };

    format!("<{}>\n{}\n</{}>\n", role, content, role)
}
492
493impl LLMInner {
494 async fn rewrite_json<T: Serialize + Debug>(fpath: &Path, t: &T) -> Result<(), PromptError> {
495 let mut json_fp = fpath.to_path_buf();
496 json_fp.set_file_name(format!(
497 "{}.json",
498 json_fp
499 .file_stem()
500 .ok_or_eyre(eyre!("no filename"))?
501 .to_str()
502 .ok_or_eyre(eyre!("non-utf fname"))?
503 ));
504
505 let mut fp = tokio::fs::OpenOptions::new()
506 .create(true)
507 .append(true)
508 .write(true)
509 .open(&json_fp)
510 .await?;
511 let s = match serde_json::to_string(&t) {
512 Ok(s) => s,
513 Err(_) => format!("{:?}", &t),
514 };
515 fp.write_all(s.as_bytes()).await?;
516 fp.write_all(b"\n").await?;
517 fp.flush().await?;
518
519 Ok(())
520 }
521
    /// Write the outgoing request (messages plus available tool schemas) to
    /// the transcript file at `fpath`, truncating any previous content, then
    /// mirror the raw request as JSON via `rewrite_json`.
    async fn save_llm_user(
        fpath: &PathBuf,
        user_msg: &CreateChatCompletionRequest,
    ) -> Result<(), PromptError> {
        let mut fp = tokio::fs::OpenOptions::new()
            .create(true)
            .truncate(true)
            .write(true)
            .open(&fpath)
            .await?;
        fp.write_all(b"=====================\n<Request>\n").await?;
        for it in user_msg.messages.iter() {
            let msg = completion_to_string(it);
            fp.write_all(msg.as_bytes()).await?;
        }

        // Render each declared tool; `tools` is Option<Vec<_>>, flattened to
        // an empty iteration when absent.
        let mut tools = vec![];
        for tool in user_msg
            .tools
            .as_ref()
            .map(|t| t.iter())
            .into_iter()
            .flatten()
        {
            let s = match tool {
                ChatCompletionTools::Function(tool) => {
                    format!(
                        "<tool name=\"{}\", description=\"{}\", strict={}>\n{}\n</tool>",
                        &tool.function.name,
                        &tool.function.description.clone().unwrap_or_default(),
                        tool.function.strict.unwrap_or_default(),
                        // Pretty-print the JSON-schema parameters if present.
                        tool.function
                            .parameters
                            .as_ref()
                            .map(serde_json::to_string_pretty)
                            .transpose()?
                            .unwrap_or_default()
                    )
                }
                ChatCompletionTools::Custom(tool) => {
                    format!(
                        "<customtool name=\"{}\", description=\"{:?}\"></customtool>",
                        tool.custom.name, tool.custom.description
                    )
                }
            };
            tools.push(s);
        }
        fp.write_all(tools.join("\n").as_bytes()).await?;
        fp.write_all(b"\n</Request>\n=====================\n")
            .await?;
        fp.flush().await?;

        Self::rewrite_json(fpath, user_msg).await?;

        Ok(())
    }
579
580 async fn save_llm_resp(fpath: &PathBuf, resp: &CreateChatCompletionResponse) -> Result<()> {
581 let mut fp = tokio::fs::OpenOptions::new()
582 .create(false)
583 .append(true)
584 .write(true)
585 .open(&fpath)
586 .await?;
587 fp.write_all(b"=====================\n<Response>\n").await?;
588 for it in &resp.choices {
589 let msg = response_to_string(&it.message);
590 fp.write_all(msg.as_bytes()).await?;
591 }
592 fp.write_all(b"\n</Response>\n=====================\n")
593 .await?;
594 fp.flush().await?;
595
596 Self::rewrite_json(fpath, resp).await?;
597
598 Ok(())
599 }
600
    /// When LLM debugging is enabled, allocate the next transcript path
    /// (`<prefix>-<zero-padded index>.xml`) under the debug directory,
    /// bumping the shared atomic index; returns `None` otherwise.
    fn on_llm_debug(&self, prefix: &str) -> Option<PathBuf> {
        if let Some(output_folder) = self.llm_debug.as_ref() {
            let idx = self.llm_debug_index.fetch_add(1, Ordering::SeqCst);
            let fpath = output_folder.join(format!("{}-{:0>12}.xml", prefix, idx));
            Some(fpath)
        } else {
            None
        }
    }
610
611 pub async fn prompt_once_with_retry(
613 &self,
614 sys_msg: &str,
615 user_msg: &str,
616 prefix: Option<&str>,
617 settings: Option<LLMSettings>,
618 ) -> Result<CreateChatCompletionResponse, PromptError> {
619 let settings = settings.unwrap_or_else(|| self.default_settings.clone());
620 let sys = ChatCompletionRequestSystemMessageArgs::default()
621 .content(sys_msg)
622 .build()?;
623
624 let user = ChatCompletionRequestUserMessageArgs::default()
625 .content(user_msg)
626 .build()?;
627 let mut req = CreateChatCompletionRequestArgs::default();
628 req.messages(vec![sys.into(), user.into()])
629 .model(self.model.to_string())
630 .temperature(settings.llm_temperature)
631 .presence_penalty(settings.llm_presence_penalty)
632 .max_completion_tokens(settings.llm_max_completion_tokens);
633
634 if let Some(tc) = settings.llm_tool_choice {
635 req.tool_choice(tc);
636 }
637 if let Some(prefix) = prefix {
638 req.prompt_cache_key(prefix.to_string());
639 }
640 let req = req.build()?;
641
642 let timeout = if settings.llm_prompt_timeout == 0 {
643 Duration::MAX
644 } else {
645 Duration::from_secs(settings.llm_prompt_timeout)
646 };
647
648 self.complete_once_with_retry(&req, prefix, Some(timeout), Some(settings.llm_retry))
649 .await
650 }
651
    /// Call [`Self::complete`] up to `retry` times, each attempt bounded by
    /// `timeout`. `None` means "no timeout" / "retry forever" respectively.
    ///
    /// Returns the first success; after exhausting all attempts, returns the
    /// most recent non-timeout error (timeouts themselves leave `last`
    /// untouched).
    pub async fn complete_once_with_retry(
        &self,
        req: &CreateChatCompletionRequest,
        prefix: Option<&str>,
        timeout: Option<Duration>,
        retry: Option<u64>,
    ) -> Result<CreateChatCompletionResponse, PromptError> {
        let timeout = if let Some(timeout) = timeout {
            timeout
        } else {
            Duration::MAX
        };

        let retry = if let Some(retry) = retry {
            retry
        } else {
            u64::MAX
        };

        // `last` holds the outcome of the most recent attempt that actually
        // completed (success or API error); timeouts skip the update.
        let mut last = None;
        for idx in 0..retry {
            match tokio::time::timeout(timeout, self.complete(req.clone(), prefix)).await {
                Ok(r) => {
                    last = Some(r);
                }
                Err(_) => {
                    warn!("Timeout with {} retry, timeout = {:?}", idx, timeout);
                    continue;
                }
            };

            match last {
                Some(Ok(r)) => return Ok(r),
                Some(Err(ref e)) => {
                    warn!(
                        "Having an error {} during {} retry (timeout is {:?})",
                        e, idx, timeout
                    );
                }
                _ => {}
            }
        }

        // Only reachable when retry == 0 or every attempt timed out/failed;
        // the outer `?` unwraps "no attempt ever finished", the returned
        // inner Result carries the last real error.
        last.ok_or_eyre(eyre!("retry is zero?!"))
            .map_err(PromptError::Other)?
    }
698
699 pub async fn complete(
700 &self,
701 req: CreateChatCompletionRequest,
702 prefix: Option<&str>,
703 ) -> Result<CreateChatCompletionResponse, PromptError> {
704 let use_stream = self.default_settings.llm_stream;
705 let prefix = if let Some(prefix) = prefix {
706 prefix.to_string()
707 } else {
708 "llm".to_string()
709 };
710 let debug_fp = self.on_llm_debug(&prefix);
711
712 if let Some(debug_fp) = debug_fp.as_ref() {
713 if let Err(e) = Self::save_llm_user(debug_fp, &req).await {
714 warn!("Fail to save user due to {}", e);
715 }
716 }
717
718 trace!(
719 "Sending completion request: {:?}",
720 &serde_json::to_string(&req)
721 );
722 let resp = if use_stream {
723 self.complete_streaming(req).await?
724 } else {
725 self.client.create_chat(req).await?
726 };
727
728 if let Some(debug_fp) = debug_fp.as_ref() {
729 if let Err(e) = Self::save_llm_resp(debug_fp, &resp).await {
730 warn!("Fail to save resp due to {}", e);
731 }
732 }
733
734 if let Some(usage) = &resp.usage {
735 let cached = usage
736 .prompt_tokens_details
737 .as_ref()
738 .map(|v| v.cached_tokens)
739 .flatten()
740 .unwrap_or_default();
741 let input = usage.prompt_tokens - cached;
742 self.billing
743 .write()
744 .await
745 .input_tokens(&self.model, input as _, cached as _)
746 .map_err(PromptError::Other)?;
747 self.billing
748 .write()
749 .await
750 .output_tokens(&self.model, usage.completion_tokens as u64)
751 .map_err(PromptError::Other)?;
752 } else {
753 warn!("No usage?!")
754 }
755
756 info!("Model Billing: {}", &self.billing.read().await);
757 Ok(resp)
758 }
759
    /// Drive a streaming completion and reassemble the chunks into a single
    /// [`CreateChatCompletionResponse`] shaped like a non-streaming reply.
    ///
    /// Accumulates per-choice text, finish reasons and fragmented tool calls,
    /// keeping the last-seen metadata (model, service tier, usage) from the
    /// chunk stream.
    async fn complete_streaming(
        &self,
        mut req: CreateChatCompletionRequest,
    ) -> Result<CreateChatCompletionResponse, PromptError> {
        // Ask the server to report token usage in the final chunk, unless the
        // caller already configured stream options.
        if req.stream_options.is_none() {
            req.stream_options = Some(ChatCompletionStreamOptions {
                include_usage: Some(true),
                include_obfuscation: None,
            });
        }

        let mut stream = self.client.create_chat_stream(req).await?;

        // Response-level metadata, filled from chunks as they arrive.
        let mut id: Option<String> = None;
        let mut created: Option<u32> = None;
        let mut model: Option<String> = None;
        let mut service_tier = None;
        let mut system_fingerprint = None;
        let mut usage: Option<CompletionUsage> = None;

        // Per-choice accumulators, indexed by the chunk's choice index.
        let mut contents: Vec<String> = Vec::new();
        let mut finish_reasons: Vec<Option<FinishReason>> = Vec::new();
        let mut tool_calls: Vec<Vec<ToolCallAcc>> = Vec::new();

        while let Some(item) = stream.next().await {
            let chunk: CreateChatCompletionStreamResponse = item?;
            // Keep the first id, but the latest of everything else.
            if id.is_none() {
                id = Some(chunk.id.clone());
            }
            created = Some(chunk.created);
            model = Some(chunk.model.clone());
            service_tier = chunk.service_tier.clone();
            system_fingerprint = chunk.system_fingerprint.clone();
            if let Some(u) = chunk.usage.clone() {
                usage = Some(u);
            }

            for ch in chunk.choices.into_iter() {
                let idx = ch.index as usize;
                // Grow all accumulators in lock-step so indexing stays valid.
                if contents.len() <= idx {
                    contents.resize_with(idx + 1, String::new);
                    finish_reasons.resize_with(idx + 1, || None);
                    tool_calls.resize_with(idx + 1, Vec::new);
                }
                if let Some(delta) = ch.delta.content {
                    contents[idx].push_str(&delta);
                }
                if let Some(tcs) = ch.delta.tool_calls {
                    for tc in tcs.into_iter() {
                        let tc_idx = tc.index as usize;
                        if tool_calls[idx].len() <= tc_idx {
                            tool_calls[idx].resize_with(tc_idx + 1, ToolCallAcc::default);
                        }
                        // id/name arrive once; arguments arrive in fragments.
                        let acc = &mut tool_calls[idx][tc_idx];
                        if let Some(id) = tc.id {
                            acc.id = id;
                        }
                        if let Some(func) = tc.function {
                            if let Some(name) = func.name {
                                acc.name = name;
                            }
                            if let Some(args) = func.arguments {
                                acc.arguments.push_str(&args);
                            }
                        }
                    }
                }
                if ch.finish_reason.is_some() {
                    finish_reasons[idx] = ch.finish_reason;
                }
            }
        }

        // Convert the accumulators into regular (non-streaming) choices.
        let mut choices = Vec::new();
        for (idx, content) in contents.into_iter().enumerate() {
            let finish_reason = finish_reasons.get(idx).cloned().unwrap_or(None);
            let built_tool_calls = tool_calls
                .get(idx)
                .cloned()
                .unwrap_or_default()
                .into_iter()
                // Drop fully-empty accumulator slots.
                .filter(|t| !t.name.trim().is_empty() || !t.arguments.trim().is_empty())
                .map(|t| {
                    ChatCompletionMessageToolCalls::Function(ChatCompletionMessageToolCall {
                        // NOTE(review): the fallback id uses the *choice*
                        // index, so several id-less tool calls in one choice
                        // would share an id — confirm whether that matters
                        // to downstream consumers.
                        id: if t.id.trim().is_empty() {
                            format!("toolcall-{}", idx)
                        } else {
                            t.id
                        },
                        function: FunctionCall {
                            name: t.name,
                            arguments: t.arguments,
                        },
                    })
                })
                .collect::<Vec<_>>();
            let tool_calls_opt = if built_tool_calls.is_empty() {
                None
            } else {
                Some(built_tool_calls)
            };
            choices.push(ChatChoice {
                index: idx as u32,
                message: ChatCompletionResponseMessage {
                    content: if content.is_empty() {
                        None
                    } else {
                        Some(content)
                    },
                    refusal: None,
                    tool_calls: tool_calls_opt,
                    annotations: None,
                    role: Role::Assistant,
                    function_call: None,
                    audio: None,
                },
                finish_reason,
                logprobs: None,
            });
        }
        // Guarantee at least one (empty) choice so callers can index choice 0.
        if choices.is_empty() {
            choices.push(ChatChoice {
                index: 0,
                message: ChatCompletionResponseMessage {
                    content: Some(String::new()),
                    refusal: None,
                    tool_calls: None,
                    annotations: None,
                    role: Role::Assistant,
                    function_call: None,
                    audio: None,
                },
                finish_reason: None,
                logprobs: None,
            });
        }

        Ok(CreateChatCompletionResponse {
            id: id.unwrap_or_else(|| "stream".to_string()),
            choices,
            created: created.unwrap_or(0),
            model: model.unwrap_or_else(|| self.model.to_string()),
            service_tier,
            system_fingerprint,
            object: "chat.completion".to_string(),
            usage,
        })
    }
908
909 pub async fn prompt_once(
910 &self,
911 sys_msg: &str,
912 user_msg: &str,
913 prefix: Option<&str>,
914 settings: Option<LLMSettings>,
915 ) -> Result<CreateChatCompletionResponse, PromptError> {
916 let settings = settings.unwrap_or_else(|| self.default_settings.clone());
917 let sys = ChatCompletionRequestSystemMessageArgs::default()
918 .content(sys_msg)
919 .build()?;
920
921 let user = ChatCompletionRequestUserMessageArgs::default()
922 .content(user_msg)
923 .build()?;
924 let mut req = CreateChatCompletionRequestArgs::default();
925
926 if let Some(prefix) = prefix.as_ref() {
927 req.prompt_cache_key(prefix.to_string());
928 }
929 let req = req
930 .messages(vec![sys.into(), user.into()])
931 .model(self.model.to_string())
932 .temperature(settings.llm_temperature)
933 .presence_penalty(settings.llm_presence_penalty)
934 .max_completion_tokens(settings.llm_max_completion_tokens)
935 .build()?;
936 self.complete(req, prefix).await
937 }
938}