//! openai_models/llm.rs
//!
//! OpenAI / Azure-OpenAI chat-completion client wrapper: billing caps,
//! per-attempt timeouts with retry, optional streaming with local
//! response re-assembly, and on-disk request/response debug logging.
1use std::{
2    fmt::{Debug, Display},
3    ops::{Deref, DerefMut},
4    path::{Path, PathBuf},
5    str::FromStr,
6    sync::{
7        Arc,
8        atomic::{AtomicU64, Ordering},
9    },
10    time::Duration,
11};
12
13use async_openai::{
14    Client,
15    config::{AzureConfig, OpenAIConfig},
16    error::OpenAIError,
17    types::chat::{
18        ChatChoice, ChatCompletionMessageToolCall, ChatCompletionMessageToolCalls,
19        ChatCompletionNamedToolChoiceCustom, ChatCompletionRequestAssistantMessageContent,
20        ChatCompletionRequestAssistantMessageContentPart,
21        ChatCompletionRequestDeveloperMessageContent,
22        ChatCompletionRequestDeveloperMessageContentPart, ChatCompletionRequestMessage,
23        ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestSystemMessageContent,
24        ChatCompletionRequestSystemMessageContentPart, ChatCompletionRequestToolMessageContent,
25        ChatCompletionRequestToolMessageContentPart, ChatCompletionRequestUserMessageArgs,
26        ChatCompletionRequestUserMessageContent, ChatCompletionRequestUserMessageContentPart,
27        ChatCompletionResponseMessage, ChatCompletionResponseStream, ChatCompletionStreamOptions,
28        ChatCompletionToolChoiceOption, ChatCompletionTools, CompletionUsage,
29        CreateChatCompletionRequest, CreateChatCompletionRequestArgs, CreateChatCompletionResponse,
30        CreateChatCompletionStreamResponse, CustomName, FinishReason, FunctionCall, Role,
31        ToolChoiceOptions,
32    },
33};
34use clap::Args;
35use color_eyre::{
36    Result,
37    eyre::{OptionExt, eyre},
38};
39use futures_util::StreamExt;
40use itertools::Itertools;
41use log::{debug, info, trace, warn};
42use serde::{Deserialize, Serialize};
43use tokio::{io::AsyncWriteExt, sync::RwLock};
44
45use crate::{OpenAIModel, error::PromptError};
46
// Accumulator for one tool call assembled from streaming deltas (see
// `complete_streaming`): `id` and `name` are overwritten by whichever
// delta carries them, while `arguments` is appended to chunk by chunk.
#[derive(Clone, Debug, Default)]
struct ToolCallAcc {
    id: String,
    name: String,
    arguments: String,
}
53
// Upstream implementation is flawed
/// Newtype around `ChatCompletionToolChoiceOption` so we can provide our
/// own `FromStr` (used by clap to parse `--llm-tool-choice`).
#[derive(Debug, Clone)]
pub struct LLMToolChoice(pub ChatCompletionToolChoiceOption);
57
58impl FromStr for LLMToolChoice {
59    type Err = PromptError;
60    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
61        Ok(match s {
62            "auto" => Self(ChatCompletionToolChoiceOption::Mode(
63                ToolChoiceOptions::Auto,
64            )),
65            "required" => Self(ChatCompletionToolChoiceOption::Mode(
66                ToolChoiceOptions::Required,
67            )),
68            "none" => Self(ChatCompletionToolChoiceOption::Mode(
69                ToolChoiceOptions::None,
70            )),
71            _ => Self(ChatCompletionToolChoiceOption::Custom(
72                ChatCompletionNamedToolChoiceCustom {
73                    custom: CustomName {
74                        name: s.to_string(),
75                    },
76                },
77            )),
78        })
79    }
80}
81
/// Transparent read access to the wrapped option.
impl Deref for LLMToolChoice {
    type Target = ChatCompletionToolChoiceOption;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
88
/// Transparent mutable access to the wrapped option.
impl DerefMut for LLMToolChoice {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}
94
/// Wrap an upstream option into the newtype.
impl From<ChatCompletionToolChoiceOption> for LLMToolChoice {
    fn from(value: ChatCompletionToolChoiceOption) -> Self {
        Self(value)
    }
}
100
/// Unwrap back to the upstream option (e.g. for the request builder's
/// `tool_choice` setter).
impl From<LLMToolChoice> for ChatCompletionToolChoiceOption {
    fn from(value: LLMToolChoice) -> Self {
        value.0
    }
}
106
// Per-request LLM tunables, bound to CLI flags / environment via clap.
// These act as process-wide defaults; individual prompts may pass their own
// `LLMSettings` to override them. Plain `//` comments are used on purpose:
// `///` doc comments would become clap help text and change CLI output.
#[derive(Args, Clone, Debug)]
pub struct LLMSettings {
    // Sampling temperature forwarded to the chat completion request.
    #[arg(long, env = "LLM_TEMPERATURE", default_value_t = 0.8)]
    pub llm_temperature: f32,

    // Presence penalty forwarded to the chat completion request.
    #[arg(long, env = "LLM_PRESENCE_PENALTY", default_value_t = 0.0)]
    pub llm_presence_penalty: f32,

    // Per-attempt timeout in seconds; 0 disables the timeout
    // (see `prompt_once_with_retry`).
    #[arg(long, env = "LLM_PROMPT_TIMEOUT", default_value_t = 120)]
    pub llm_prompt_timeout: u64,

    // Maximum number of attempts for retry-enabled prompts.
    #[arg(long, env = "LLM_RETRY", default_value_t = 5)]
    pub llm_retry: u64,

    // Cap on completion tokens per request.
    #[arg(long, env = "LLM_MAX_COMPLETION_TOKENS", default_value_t = 16384)]
    pub llm_max_completion_tokens: u32,

    // Optional tool-choice: "auto" / "required" / "none" or a custom tool
    // name (see `LLMToolChoice::from_str`).
    // NOTE(review): env var name is misspelled ("CHOINCE"); renaming it
    // would break existing deployments — confirm before fixing.
    #[arg(long, env = "LLM_TOOL_CHOINCE")]
    pub llm_tool_choice: Option<LLMToolChoice>,

    // Use the streaming endpoint and re-assemble the response locally
    // (see `complete_streaming`).
    #[arg(
        long,
        env = "LLM_STREAM",
        default_value_t = false,
        value_parser = clap::builder::BoolishValueParser::new()
    )]
    pub llm_stream: bool,
}
135
// Connection / account configuration for either plain OpenAI or Azure
// OpenAI, plus the default LLM settings. Presence of `azure_openai_endpoint`
// selects the Azure backend (see `to_config`). Plain `//` comments on
// purpose: `///` would become clap help text.
#[derive(Args, Clone, Debug)]
pub struct OpenAISetup {
    // Base URL for the plain-OpenAI backend.
    #[arg(
        long,
        env = "OPENAI_API_URL",
        default_value = "https://api.openai.com/v1"
    )]
    pub openai_url: String,

    // When set, the Azure backend is used instead of `openai_url`.
    #[arg(long, env = "AZURE_OPENAI_ENDPOINT")]
    pub azure_openai_endpoint: Option<String>,

    // API key for either backend; empty string when unset.
    #[arg(long, env = "OPENAI_API_KEY")]
    pub openai_key: Option<String>,

    // Azure deployment id; defaults to the model name when unset.
    #[arg(long, env = "AZURE_API_DEPLOYMENT")]
    pub azure_deployment: Option<String>,

    #[arg(long, env = "AZURE_API_VERSION", default_value = "2025-01-01-preview")]
    pub azure_api_version: String,

    // Spending cap in USD enforced by `ModelBilling`.
    // NOTE(review): field name is misspelled ("biling") — it is public API,
    // so renaming would break callers; confirm before fixing.
    #[arg(long, default_value_t = 10.0, env = "OPENAI_BILLING_CAP")]
    pub biling_cap: f64,

    #[arg(long, env = "OPENAI_API_MODEL", default_value = "o1")]
    pub model: OpenAIModel,

    // Directory under which per-process LLM debug folders are created.
    #[arg(long, env = "LLM_DEBUG")]
    pub llm_debug: Option<PathBuf>,

    #[clap(flatten)]
    pub llm_settings: LLMSettings,
}
169
170impl OpenAISetup {
171    pub fn to_config(&self) -> SupportedConfig {
172        if let Some(ep) = self.azure_openai_endpoint.as_ref() {
173            let cfg = AzureConfig::new()
174                .with_api_base(ep)
175                .with_api_key(self.openai_key.clone().unwrap_or_default())
176                .with_deployment_id(
177                    self.azure_deployment
178                        .as_ref()
179                        .unwrap_or(&self.model.to_string()),
180                )
181                .with_api_version(&self.azure_api_version);
182            SupportedConfig::Azure(cfg)
183        } else {
184            let cfg = OpenAIConfig::new()
185                .with_api_base(&self.openai_url)
186                .with_api_key(self.openai_key.clone().unwrap_or_default());
187            SupportedConfig::OpenAI(cfg)
188        }
189    }
190
191    pub fn to_llm(&self) -> LLM {
192        let billing = RwLock::new(ModelBilling::new(self.biling_cap));
193
194        let debug_path = if let Some(dbg) = self.llm_debug.as_ref() {
195            let pid = std::process::id();
196
197            let mut cnt = 0u64;
198            let debug_path;
199            loop {
200                let test_path = dbg.join(format!("{}-{}", pid, cnt));
201                if !test_path.exists() {
202                    std::fs::create_dir_all(&test_path).expect("Fail to create llm debug path?");
203                    debug_path = Some(test_path);
204                    debug!("The path to save LLM interactions is {:?}", &debug_path);
205                    break;
206                } else {
207                    cnt += 1;
208                }
209            }
210            debug_path
211        } else {
212            None
213        };
214
215        LLM {
216            llm: Arc::new(LLMInner {
217                client: LLMClient::new(self.to_config()),
218                model: self.model.clone(),
219                billing,
220                llm_debug: debug_path,
221                llm_debug_index: AtomicU64::new(0),
222                default_settings: self.llm_settings.clone(),
223            }),
224        }
225    }
226}
227
/// Backend-specific `async_openai` configuration selected by
/// `OpenAISetup::to_config`.
#[derive(Debug, Clone)]
pub enum SupportedConfig {
    Azure(AzureConfig),
    OpenAI(OpenAIConfig),
}
233
/// A client bound to one of the supported backends; both variants expose
/// the same chat-completion operations via the methods below.
#[derive(Debug, Clone)]
pub enum LLMClient {
    Azure(Client<AzureConfig>),
    OpenAI(Client<OpenAIConfig>),
}
239
impl LLMClient {
    /// Wrap a configured `async_openai` client for the selected backend.
    pub fn new(config: SupportedConfig) -> Self {
        match config {
            SupportedConfig::Azure(cfg) => Self::Azure(Client::with_config(cfg)),
            SupportedConfig::OpenAI(cfg) => Self::OpenAI(Client::with_config(cfg)),
        }
    }

    /// One-shot (non-streaming) chat completion on whichever backend this
    /// client wraps.
    pub async fn create_chat(
        &self,
        req: CreateChatCompletionRequest,
    ) -> Result<CreateChatCompletionResponse, OpenAIError> {
        match self {
            Self::Azure(cl) => cl.chat().create(req).await,
            Self::OpenAI(cl) => cl.chat().create(req).await,
        }
    }

    /// Open a streaming chat completion on whichever backend this client
    /// wraps.
    pub async fn create_chat_stream(
        &self,
        req: CreateChatCompletionRequest,
    ) -> Result<ChatCompletionResponseStream, OpenAIError> {
        match self {
            Self::Azure(cl) => cl.chat().create_stream(req).await,
            Self::OpenAI(cl) => cl.chat().create_stream(req).await,
        }
    }
}
268
/// Running USD spend tracker with a hard cap; token charges are added by
/// `input_tokens` / `output_tokens` and fail once `current` exceeds `cap`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelBilling {
    pub current: f64,
    pub cap: f64,
}
274
275impl Display for ModelBilling {
276    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
277        f.write_fmt(format_args!("Billing({}/{})", self.current, self.cap))
278    }
279}
280
281impl ModelBilling {
282    pub fn new(cap: f64) -> Self {
283        Self { current: 0.0, cap }
284    }
285
286    pub fn in_cap(&self) -> bool {
287        self.current <= self.cap
288    }
289
290    pub fn input_tokens(
291        &mut self,
292        model: &OpenAIModel,
293        input_count: u64,
294        cached_count: u64,
295    ) -> Result<()> {
296        let pricing = model.pricing();
297
298        let cached_price = if let Some(cached) = pricing.cached_input_tokens {
299            cached
300        } else {
301            pricing.input_tokens
302        };
303
304        let cached_usd = (cached_price * (cached_count as f64)) / 1e6;
305        let raw_input_usd = (pricing.input_tokens * (input_count as f64)) / 1e6;
306
307        log::debug!(
308            "Input token usage: cached {:.2} USD, {} tokens / input: {:.2} USD, {} tokens",
309            cached_usd,
310            cached_count,
311            raw_input_usd,
312            input_count
313        );
314        self.current += cached_usd + raw_input_usd;
315
316        if self.in_cap() {
317            Ok(())
318        } else {
319            Err(eyre!("cap {} reached, current {}", self.cap, self.current))
320        }
321    }
322
323    pub fn output_tokens(&mut self, model: &OpenAIModel, count: u64) -> Result<()> {
324        let pricing = model.pricing();
325
326        let output_usd = pricing.output_tokens * (count as f64) / 1e6;
327        log::debug!("Output token usage: {} USD, {} tokens", output_usd, count);
328        self.current += output_usd;
329
330        if self.in_cap() {
331            Ok(())
332        } else {
333            Err(eyre!("cap {} reached, current {}", self.cap, self.current))
334        }
335    }
336}
337
/// Cheaply-cloneable handle to the shared `LLMInner` state (client,
/// billing, debug settings).
#[derive(Debug, Clone)]
pub struct LLM {
    pub llm: Arc<LLMInner>,
}
342
/// Let callers invoke `LLMInner` methods directly on the handle.
impl Deref for LLM {
    type Target = LLMInner;

    fn deref(&self) -> &Self::Target {
        &self.llm
    }
}
350
/// Shared LLM state: the backend client, the billing tracker, optional
/// debug-output directory with a monotonically increasing file index, and
/// the default per-request settings.
#[derive(Debug)]
pub struct LLMInner {
    pub client: LLMClient,
    pub model: OpenAIModel,
    pub billing: RwLock<ModelBilling>,
    // Directory for request/response dumps; `None` disables debugging.
    pub llm_debug: Option<PathBuf>,
    // Counter used by `on_llm_debug` to generate unique file names.
    pub llm_debug_index: AtomicU64,
    pub default_settings: LLMSettings,
}
360
/// Map a request-side message to the upper-case role tag used in the
/// XML-ish debug dumps produced by `completion_to_string`.
pub fn completion_to_role(msg: &ChatCompletionRequestMessage) -> &'static str {
    match msg {
        ChatCompletionRequestMessage::Assistant(_) => "ASSISTANT",
        ChatCompletionRequestMessage::Developer(_) => "DEVELOPER",
        ChatCompletionRequestMessage::Function(_) => "FUNCTION",
        ChatCompletionRequestMessage::System(_) => "SYSTEM",
        ChatCompletionRequestMessage::Tool(_) => "TOOL",
        ChatCompletionRequestMessage::User(_) => "USER",
    }
}
371
372pub fn toolcall_to_string(t: &ChatCompletionMessageToolCalls) -> String {
373    match t {
374        ChatCompletionMessageToolCalls::Function(t) => {
375            format!(
376                "<toolcall name=\"{}\">\n{}\n</toolcall>",
377                &t.function.name, &t.function.arguments
378            )
379        }
380        ChatCompletionMessageToolCalls::Custom(t) => {
381            format!(
382                "<customtoolcall name=\"{}\">\n{}\n</customtoolcall>",
383                &t.custom_tool.name, &t.custom_tool.input
384            )
385        }
386    }
387}
388
389pub fn response_to_string(resp: &ChatCompletionResponseMessage) -> String {
390    let mut s = String::new();
391    if let Some(content) = resp.content.as_ref() {
392        s += content;
393        s += "\n";
394    }
395
396    if let Some(tools) = resp.tool_calls.as_ref() {
397        s += &tools.iter().map(|t| toolcall_to_string(t)).join("\n");
398    }
399
400    if let Some(refusal) = &resp.refusal {
401        s += refusal;
402        s += "\n";
403    }
404
405    let role = resp.role.to_string().to_uppercase();
406
407    format!("<{}>\n{}\n</{}>\n", &role, s, &role)
408}
409
410pub fn completion_to_string(msg: &ChatCompletionRequestMessage) -> String {
411    const CONT: &str = "<cont/>\n";
412    const NONE: &str = "<none/>\n";
413    let role = completion_to_role(msg);
414    let content = match msg {
415        ChatCompletionRequestMessage::Assistant(ass) => {
416            let msg = ass
417                .content
418                .as_ref()
419                .map(|ass| match ass {
420                    ChatCompletionRequestAssistantMessageContent::Text(s) => s.clone(),
421                    ChatCompletionRequestAssistantMessageContent::Array(arr) => arr
422                        .iter()
423                        .map(|v| match v {
424                            ChatCompletionRequestAssistantMessageContentPart::Text(s) => {
425                                s.text.clone()
426                            }
427                            ChatCompletionRequestAssistantMessageContentPart::Refusal(rf) => {
428                                rf.refusal.clone()
429                            }
430                        })
431                        .join(CONT),
432                })
433                .unwrap_or(NONE.to_string());
434            let tool_calls = ass
435                .tool_calls
436                .iter()
437                .flatten()
438                .map(|t| toolcall_to_string(t))
439                .join("\n");
440            format!("{}\n{}", msg, tool_calls)
441        }
442        ChatCompletionRequestMessage::Developer(dev) => match &dev.content {
443            ChatCompletionRequestDeveloperMessageContent::Text(t) => t.clone(),
444            ChatCompletionRequestDeveloperMessageContent::Array(arr) => arr
445                .iter()
446                .map(|v| match v {
447                    ChatCompletionRequestDeveloperMessageContentPart::Text(v) => v.text.clone(),
448                })
449                .join(CONT),
450        },
451        ChatCompletionRequestMessage::Function(f) => f.content.clone().unwrap_or(NONE.to_string()),
452        ChatCompletionRequestMessage::System(sys) => match &sys.content {
453            ChatCompletionRequestSystemMessageContent::Text(t) => t.clone(),
454            ChatCompletionRequestSystemMessageContent::Array(arr) => arr
455                .iter()
456                .map(|v| match v {
457                    ChatCompletionRequestSystemMessageContentPart::Text(t) => t.text.clone(),
458                })
459                .join(CONT),
460        },
461        ChatCompletionRequestMessage::Tool(tool) => match &tool.content {
462            ChatCompletionRequestToolMessageContent::Text(t) => t.clone(),
463            ChatCompletionRequestToolMessageContent::Array(arr) => arr
464                .iter()
465                .map(|v| match v {
466                    ChatCompletionRequestToolMessageContentPart::Text(t) => t.text.clone(),
467                })
468                .join(CONT),
469        },
470        ChatCompletionRequestMessage::User(usr) => match &usr.content {
471            ChatCompletionRequestUserMessageContent::Text(t) => t.clone(),
472            ChatCompletionRequestUserMessageContent::Array(arr) => arr
473                .iter()
474                .map(|v| match v {
475                    ChatCompletionRequestUserMessageContentPart::Text(t) => t.text.clone(),
476                    ChatCompletionRequestUserMessageContentPart::ImageUrl(img) => {
477                        format!("<img url=\"{}\"/>", &img.image_url.url)
478                    }
479                    ChatCompletionRequestUserMessageContentPart::InputAudio(audio) => {
480                        format!("<audio>{}</audio>", audio.input_audio.data)
481                    }
482                    ChatCompletionRequestUserMessageContentPart::File(f) => {
483                        format!("<file>{:?}</file>", f)
484                    }
485                })
486                .join(CONT),
487        },
488    };
489
490    format!("<{}>\n{}\n</{}>\n", role, content, role)
491}
492
impl LLMInner {
    /// Serialize `t` as one JSON line into a `.json` sidecar derived from
    /// `fpath` (same directory and stem, `.json` extension), falling back to
    /// the `Debug` representation when serialization fails.
    ///
    /// Despite the name, the sidecar is opened in append mode, so the
    /// request line from `save_llm_user` and the response line from
    /// `save_llm_resp` accumulate in the same file.
    async fn rewrite_json<T: Serialize + Debug>(fpath: &Path, t: &T) -> Result<(), PromptError> {
        let mut json_fp = fpath.to_path_buf();
        json_fp.set_file_name(format!(
            "{}.json",
            json_fp
                .file_stem()
                .ok_or_eyre(eyre!("no filename"))?
                .to_str()
                .ok_or_eyre(eyre!("non-utf fname"))?
        ));

        let mut fp = tokio::fs::OpenOptions::new()
            .create(true)
            .append(true)
            .write(true)
            .open(&json_fp)
            .await?;
        let s = match serde_json::to_string(&t) {
            Ok(s) => s,
            // Best-effort: dump Debug output rather than fail the save.
            Err(_) => format!("{:?}", &t),
        };
        fp.write_all(s.as_bytes()).await?;
        fp.write_all(b"\n").await?;
        fp.flush().await?;

        Ok(())
    }

    /// Write the outgoing request (all messages, then declared tools) to
    /// the debug file at `fpath` in a human-readable XML-ish format,
    /// truncating any previous content, then append the raw JSON to the
    /// sidecar via `rewrite_json`.
    async fn save_llm_user(
        fpath: &PathBuf,
        user_msg: &CreateChatCompletionRequest,
    ) -> Result<(), PromptError> {
        let mut fp = tokio::fs::OpenOptions::new()
            .create(true)
            .truncate(true)
            .write(true)
            .open(&fpath)
            .await?;
        fp.write_all(b"=====================\n<Request>\n").await?;
        for it in user_msg.messages.iter() {
            let msg = completion_to_string(it);
            fp.write_all(msg.as_bytes()).await?;
        }

        // Render each declared tool (function or custom) after the messages.
        let mut tools = vec![];
        for tool in user_msg
            .tools
            .as_ref()
            .map(|t| t.iter())
            .into_iter()
            .flatten()
        {
            let s = match tool {
                ChatCompletionTools::Function(tool) => {
                    format!(
                        "<tool name=\"{}\", description=\"{}\", strict={}>\n{}\n</tool>",
                        &tool.function.name,
                        &tool.function.description.clone().unwrap_or_default(),
                        tool.function.strict.unwrap_or_default(),
                        tool.function
                            .parameters
                            .as_ref()
                            .map(serde_json::to_string_pretty)
                            .transpose()?
                            .unwrap_or_default()
                    )
                }
                ChatCompletionTools::Custom(tool) => {
                    format!(
                        "<customtool name=\"{}\", description=\"{:?}\"></customtool>",
                        tool.custom.name, tool.custom.description
                    )
                }
            };
            tools.push(s);
        }
        fp.write_all(tools.join("\n").as_bytes()).await?;
        fp.write_all(b"\n</Request>\n=====================\n")
            .await?;
        fp.flush().await?;

        Self::rewrite_json(fpath, user_msg).await?;

        Ok(())
    }

    /// Append the response choices to the debug file (which must already
    /// exist — `create(false)` assumes `save_llm_user` ran first) and append
    /// the raw JSON to the sidecar.
    async fn save_llm_resp(fpath: &PathBuf, resp: &CreateChatCompletionResponse) -> Result<()> {
        let mut fp = tokio::fs::OpenOptions::new()
            .create(false)
            .append(true)
            .write(true)
            .open(&fpath)
            .await?;
        fp.write_all(b"=====================\n<Response>\n").await?;
        for it in &resp.choices {
            let msg = response_to_string(&it.message);
            fp.write_all(msg.as_bytes()).await?;
        }
        fp.write_all(b"\n</Response>\n=====================\n")
            .await?;
        fp.flush().await?;

        Self::rewrite_json(fpath, resp).await?;

        Ok(())
    }

    /// Reserve the next debug file path (`<prefix>-<zero-padded idx>.xml`)
    /// when LLM debugging is enabled; returns `None` otherwise.
    fn on_llm_debug(&self, prefix: &str) -> Option<PathBuf> {
        if let Some(output_folder) = self.llm_debug.as_ref() {
            // Atomic counter keeps names unique across concurrent requests.
            let idx = self.llm_debug_index.fetch_add(1, Ordering::SeqCst);
            let fpath = output_folder.join(format!("{}-{:0>12}.xml", prefix, idx));
            Some(fpath)
        } else {
            None
        }
    }

    // we use t/s to estimate a timeout to avoid infinite repeating
    /// Build a system+user chat request from `settings` (or the defaults)
    /// and run it through `complete_once_with_retry`.
    ///
    /// `prefix` doubles as the debug-file prefix and the prompt cache key.
    /// A `llm_prompt_timeout` of 0 disables the per-attempt timeout.
    pub async fn prompt_once_with_retry(
        &self,
        sys_msg: &str,
        user_msg: &str,
        prefix: Option<&str>,
        settings: Option<LLMSettings>,
    ) -> Result<CreateChatCompletionResponse, PromptError> {
        let settings = settings.unwrap_or_else(|| self.default_settings.clone());
        let sys = ChatCompletionRequestSystemMessageArgs::default()
            .content(sys_msg)
            .build()?;

        let user = ChatCompletionRequestUserMessageArgs::default()
            .content(user_msg)
            .build()?;
        let mut req = CreateChatCompletionRequestArgs::default();
        req.messages(vec![sys.into(), user.into()])
            .model(self.model.to_string())
            .temperature(settings.llm_temperature)
            .presence_penalty(settings.llm_presence_penalty)
            .max_completion_tokens(settings.llm_max_completion_tokens);

        if let Some(tc) = settings.llm_tool_choice {
            req.tool_choice(tc);
        }
        if let Some(prefix) = prefix {
            req.prompt_cache_key(prefix.to_string());
        }
        let req = req.build()?;

        // 0 means "no timeout".
        let timeout = if settings.llm_prompt_timeout == 0 {
            Duration::MAX
        } else {
            Duration::from_secs(settings.llm_prompt_timeout)
        };

        self.complete_once_with_retry(&req, prefix, Some(timeout), Some(settings.llm_retry))
            .await
    }

    /// Run `complete` up to `retry` times, bounding each attempt by
    /// `timeout`. `None` means "no timeout" / "retry forever" respectively.
    ///
    /// NOTE(review): a timed-out attempt does not update `last`, so when
    /// *every* attempt times out this returns the "retry is zero?!" error —
    /// confirm whether a dedicated timeout error is wanted.
    pub async fn complete_once_with_retry(
        &self,
        req: &CreateChatCompletionRequest,
        prefix: Option<&str>,
        timeout: Option<Duration>,
        retry: Option<u64>,
    ) -> Result<CreateChatCompletionResponse, PromptError> {
        let timeout = if let Some(timeout) = timeout {
            timeout
        } else {
            Duration::MAX
        };

        let retry = if let Some(retry) = retry {
            retry
        } else {
            u64::MAX
        };

        let mut last = None;
        for idx in 0..retry {
            match tokio::time::timeout(timeout, self.complete(req.clone(), prefix)).await {
                Ok(r) => {
                    last = Some(r);
                }
                Err(_) => {
                    // Attempt exceeded the deadline; try again.
                    warn!("Timeout with {} retry, timeout = {:?}", idx, timeout);
                    continue;
                }
            };

            match last {
                Some(Ok(r)) => return Ok(r),
                Some(Err(ref e)) => {
                    warn!(
                        "Having an error {} during {} retry (timeout is {:?})",
                        e, idx, timeout
                    );
                }
                _ => {}
            }
        }

        // All attempts failed: surface the last error, or a fallback error
        // when no attempt ever produced a result.
        last.ok_or_eyre(eyre!("retry is zero?!"))
            .map_err(PromptError::Other)?
    }

    /// Send one completion request — streaming or one-shot depending on
    /// `default_settings.llm_stream` — saving the request/response to the
    /// debug directory when enabled, and charging token usage against the
    /// billing cap (failures there become `PromptError::Other`).
    pub async fn complete(
        &self,
        req: CreateChatCompletionRequest,
        prefix: Option<&str>,
    ) -> Result<CreateChatCompletionResponse, PromptError> {
        let use_stream = self.default_settings.llm_stream;
        let prefix = if let Some(prefix) = prefix {
            prefix.to_string()
        } else {
            "llm".to_string()
        };
        let debug_fp = self.on_llm_debug(&prefix);

        // Debug-dump failures are logged, never propagated.
        if let Some(debug_fp) = debug_fp.as_ref() {
            if let Err(e) = Self::save_llm_user(debug_fp, &req).await {
                warn!("Fail to save user due to {}", e);
            }
        }

        trace!(
            "Sending completion request: {:?}",
            &serde_json::to_string(&req)
        );
        let resp = if use_stream {
            self.complete_streaming(req).await?
        } else {
            self.client.create_chat(req).await?
        };

        if let Some(debug_fp) = debug_fp.as_ref() {
            if let Err(e) = Self::save_llm_resp(debug_fp, &resp).await {
                warn!("Fail to save resp due to {}", e);
            }
        }

        if let Some(usage) = &resp.usage {
            let cached = usage
                .prompt_tokens_details
                .as_ref()
                .map(|v| v.cached_tokens)
                .flatten()
                .unwrap_or_default();
            // NOTE(review): underflows if the API ever reports more cached
            // than prompt tokens — consider saturating_sub.
            let input = usage.prompt_tokens - cached;
            self.billing
                .write()
                .await
                .input_tokens(&self.model, input as _, cached as _)
                .map_err(PromptError::Other)?;
            self.billing
                .write()
                .await
                .output_tokens(&self.model, usage.completion_tokens as u64)
                .map_err(PromptError::Other)?;
        } else {
            warn!("No usage?!")
        }

        info!("Model Billing: {}", &self.billing.read().await);
        Ok(resp)
    }

    /// Drive a streaming completion to the end and re-assemble the chunks
    /// into a single `CreateChatCompletionResponse`.
    ///
    /// Per-choice text deltas are concatenated; tool-call deltas are merged
    /// by (choice index, tool-call index) via `ToolCallAcc`; the last seen
    /// usage/metadata chunk wins. If the stream yields no choices at all, a
    /// single empty assistant choice is synthesized.
    async fn complete_streaming(
        &self,
        mut req: CreateChatCompletionRequest,
    ) -> Result<CreateChatCompletionResponse, PromptError> {
        // Ask the server to include usage in the final chunk so billing
        // still works in streaming mode.
        if req.stream_options.is_none() {
            req.stream_options = Some(ChatCompletionStreamOptions {
                include_usage: Some(true),
                include_obfuscation: None,
            });
        }

        let mut stream = self.client.create_chat_stream(req).await?;

        // Response-level metadata, filled in from the chunks as they arrive.
        let mut id: Option<String> = None;
        let mut created: Option<u32> = None;
        let mut model: Option<String> = None;
        let mut service_tier = None;
        let mut system_fingerprint = None;
        let mut usage: Option<CompletionUsage> = None;

        // Per-choice accumulators, indexed by the chunk's choice index.
        let mut contents: Vec<String> = Vec::new();
        let mut finish_reasons: Vec<Option<FinishReason>> = Vec::new();
        let mut tool_calls: Vec<Vec<ToolCallAcc>> = Vec::new();

        while let Some(item) = stream.next().await {
            let chunk: CreateChatCompletionStreamResponse = item?;
            if id.is_none() {
                id = Some(chunk.id.clone());
            }
            created = Some(chunk.created);
            model = Some(chunk.model.clone());
            service_tier = chunk.service_tier.clone();
            system_fingerprint = chunk.system_fingerprint.clone();
            if let Some(u) = chunk.usage.clone() {
                usage = Some(u);
            }

            for ch in chunk.choices.into_iter() {
                let idx = ch.index as usize;
                // Grow all three accumulators in lockstep.
                if contents.len() <= idx {
                    contents.resize_with(idx + 1, String::new);
                    finish_reasons.resize_with(idx + 1, || None);
                    tool_calls.resize_with(idx + 1, Vec::new);
                }
                if let Some(delta) = ch.delta.content {
                    contents[idx].push_str(&delta);
                }
                if let Some(tcs) = ch.delta.tool_calls {
                    for tc in tcs.into_iter() {
                        let tc_idx = tc.index as usize;
                        if tool_calls[idx].len() <= tc_idx {
                            tool_calls[idx].resize_with(tc_idx + 1, ToolCallAcc::default);
                        }
                        let acc = &mut tool_calls[idx][tc_idx];
                        // id/name replace; arguments concatenate.
                        if let Some(id) = tc.id {
                            acc.id = id;
                        }
                        if let Some(func) = tc.function {
                            if let Some(name) = func.name {
                                acc.name = name;
                            }
                            if let Some(args) = func.arguments {
                                acc.arguments.push_str(&args);
                            }
                        }
                    }
                }
                if ch.finish_reason.is_some() {
                    finish_reasons[idx] = ch.finish_reason;
                }
            }
        }

        // Materialize accumulated state into ordinary response choices.
        let mut choices = Vec::new();
        for (idx, content) in contents.into_iter().enumerate() {
            let finish_reason = finish_reasons.get(idx).cloned().unwrap_or(None);
            let built_tool_calls = tool_calls
                .get(idx)
                .cloned()
                .unwrap_or_default()
                .into_iter()
                // Drop placeholder accumulators that never received data.
                .filter(|t| !t.name.trim().is_empty() || !t.arguments.trim().is_empty())
                .map(|t| {
                    ChatCompletionMessageToolCalls::Function(ChatCompletionMessageToolCall {
                        // Synthesize an id when the stream never provided one.
                        id: if t.id.trim().is_empty() {
                            format!("toolcall-{}", idx)
                        } else {
                            t.id
                        },
                        function: FunctionCall {
                            name: t.name,
                            arguments: t.arguments,
                        },
                    })
                })
                .collect::<Vec<_>>();
            let tool_calls_opt = if built_tool_calls.is_empty() {
                None
            } else {
                Some(built_tool_calls)
            };
            choices.push(ChatChoice {
                index: idx as u32,
                message: ChatCompletionResponseMessage {
                    content: if content.is_empty() {
                        None
                    } else {
                        Some(content)
                    },
                    refusal: None,
                    tool_calls: tool_calls_opt,
                    annotations: None,
                    role: Role::Assistant,
                    function_call: None,
                    audio: None,
                },
                finish_reason,
                logprobs: None,
            });
        }
        // An entirely empty stream still yields one empty assistant choice
        // so downstream indexing of `choices[0]` cannot panic.
        if choices.is_empty() {
            choices.push(ChatChoice {
                index: 0,
                message: ChatCompletionResponseMessage {
                    content: Some(String::new()),
                    refusal: None,
                    tool_calls: None,
                    annotations: None,
                    role: Role::Assistant,
                    function_call: None,
                    audio: None,
                },
                finish_reason: None,
                logprobs: None,
            });
        }

        Ok(CreateChatCompletionResponse {
            id: id.unwrap_or_else(|| "stream".to_string()),
            choices,
            created: created.unwrap_or(0),
            model: model.unwrap_or_else(|| self.model.to_string()),
            service_tier,
            system_fingerprint,
            object: "chat.completion".to_string(),
            usage,
        })
    }

    /// Single-shot prompt without retry/timeout.
    ///
    /// NOTE(review): largely duplicates the request construction in
    /// `prompt_once_with_retry`, but does not apply `llm_tool_choice` —
    /// confirm whether that asymmetry is intentional.
    pub async fn prompt_once(
        &self,
        sys_msg: &str,
        user_msg: &str,
        prefix: Option<&str>,
        settings: Option<LLMSettings>,
    ) -> Result<CreateChatCompletionResponse, PromptError> {
        let settings = settings.unwrap_or_else(|| self.default_settings.clone());
        let sys = ChatCompletionRequestSystemMessageArgs::default()
            .content(sys_msg)
            .build()?;

        let user = ChatCompletionRequestUserMessageArgs::default()
            .content(user_msg)
            .build()?;
        let mut req = CreateChatCompletionRequestArgs::default();

        if let Some(prefix) = prefix.as_ref() {
            req.prompt_cache_key(prefix.to_string());
        }
        let req = req
            .messages(vec![sys.into(), user.into()])
            .model(self.model.to_string())
            .temperature(settings.llm_temperature)
            .presence_penalty(settings.llm_presence_penalty)
            .max_completion_tokens(settings.llm_max_completion_tokens)
            .build()?;
        self.complete(req, prefix).await
    }
}