openai_models/
llm.rs

1use std::{
2    fmt::{Debug, Display},
3    ops::{Deref, DerefMut},
4    path::{Path, PathBuf},
5    str::FromStr,
6    sync::{
7        Arc,
8        atomic::{AtomicU64, Ordering},
9    },
10    time::Duration,
11};
12
13use async_openai::{
14    Client,
15    config::{AzureConfig, OpenAIConfig},
16    error::OpenAIError,
17    types::{
18        ChatCompletionRequestAssistantMessageContent,
19        ChatCompletionRequestAssistantMessageContentPart,
20        ChatCompletionRequestDeveloperMessageContent, ChatCompletionRequestMessage,
21        ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestSystemMessageContent,
22        ChatCompletionRequestSystemMessageContentPart, ChatCompletionRequestToolMessageContent,
23        ChatCompletionRequestToolMessageContentPart, ChatCompletionRequestUserMessageArgs,
24        ChatCompletionRequestUserMessageContent, ChatCompletionRequestUserMessageContentPart,
25        ChatCompletionResponseMessage, ChatCompletionToolChoiceOption, CreateChatCompletionRequest,
26        CreateChatCompletionRequestArgs, CreateChatCompletionResponse,
27    },
28};
29use clap::Args;
30use color_eyre::{
31    Result,
32    eyre::{OptionExt, eyre},
33};
34use itertools::Itertools;
35use log::{debug, info, trace, warn};
36use serde::{Deserialize, Serialize};
37use tokio::{io::AsyncWriteExt, sync::RwLock};
38
39use crate::{OpenAIModel, error::PromptError};
40
// Upstream implementation is flawed
/// Newtype over `ChatCompletionToolChoiceOption` so the tool-choice mode can be parsed
/// from a CLI flag / environment string via the local `FromStr` impl instead of relying
/// on the upstream conversion (which the original author considers flawed).
#[derive(Debug, Clone)]
pub struct LLMToolChoice(pub ChatCompletionToolChoiceOption);
44
45impl FromStr for LLMToolChoice {
46    type Err = PromptError;
47    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
48        Ok(match s {
49            "auto" => Self(ChatCompletionToolChoiceOption::Auto),
50            "required" => Self(ChatCompletionToolChoiceOption::Required),
51            "none" => Self(ChatCompletionToolChoiceOption::None),
52            _ => Self(ChatCompletionToolChoiceOption::Named(s.into())),
53        })
54    }
55}
56
57impl Deref for LLMToolChoice {
58    type Target = ChatCompletionToolChoiceOption;
59    fn deref(&self) -> &Self::Target {
60        &self.0
61    }
62}
63
64impl DerefMut for LLMToolChoice {
65    fn deref_mut(&mut self) -> &mut Self::Target {
66        &mut self.0
67    }
68}
69
70impl From<ChatCompletionToolChoiceOption> for LLMToolChoice {
71    fn from(value: ChatCompletionToolChoiceOption) -> Self {
72        Self(value)
73    }
74}
75
76impl From<LLMToolChoice> for ChatCompletionToolChoiceOption {
77    fn from(value: LLMToolChoice) -> Self {
78        value.0
79    }
80}
81
82#[derive(Args, Clone, Debug)]
83pub struct LLMSettings {
84    #[arg(long, env = "LLM_TEMPERATURE", default_value_t = 0.8)]
85    pub llm_temperature: f32,
86
87    #[arg(long, env = "LLM_PRESENCE_PENALTY", default_value_t = 0.0)]
88    pub llm_presence_penalty: f32,
89
90    #[arg(long, env = "LLM_PROMPT_TIMEOUT", default_value_t = 120)]
91    pub llm_prompt_timeout: u64,
92
93    #[arg(long, env = "LLM_RETRY", default_value_t = 5)]
94    pub llm_retry: u64,
95
96    #[arg(long, env = "LLM_MAX_COMPLETION_TOKENS", default_value_t = 16384)]
97    pub llm_max_completion_tokens: u32,
98
99    #[arg(long, env = "LLM_TOOL_CHOINCE", default_value = "auto")]
100    pub llm_tool_choice: LLMToolChoice,
101}
102
// Command-line / environment configuration for the OpenAI-compatible backend.
// Plain `//` comments are used deliberately: `///` doc comments would become clap help text.
#[derive(Args, Clone, Debug)]
pub struct OpenAISetup {
    // Base URL of the API.
    #[arg(
        long,
        env = "OPENAI_API_URL",
        default_value = "https://api.openai.com/v1"
    )]
    pub openai_url: String,

    // API key; `to_config` substitutes an empty string when it is absent.
    #[arg(long, env = "OPENAI_API_KEY")]
    pub openai_key: Option<String>,

    // When set, `to_config` builds an Azure configuration and uses this value as the
    // deployment id; otherwise a plain OpenAI configuration is built.
    #[arg(long, env = "OPENAI_API_ENDPOINT")]
    pub openai_endpoint: Option<String>,

    // Spending cap handed to `ModelBilling::new`. It is compared against costs computed from
    // `OpenAIModel::pricing` (per-million-token prices), presumably in USD — TODO confirm.
    // The misspelled field name is public API and is therefore kept.
    #[arg(long, default_value_t = 10.0, env = "OPENAI_BILLING_CAP")]
    pub biling_cap: f64,

    // Model used for every request and for pricing.
    #[arg(long, env = "OPENAI_API_MODEL", default_value = "o1")]
    pub model: OpenAIModel,

    // Optional root directory; when set, `to_llm` creates a `<pid>-<n>` subdirectory
    // in which each request/response pair is dumped.
    #[arg(long, env = "LLM_DEBUG")]
    pub llm_debug: Option<PathBuf>,

    // Sampling/timeout/retry defaults, flattened into the same argument namespace.
    #[clap(flatten)]
    pub llm_settings: LLMSettings,
}
130
131impl OpenAISetup {
132    pub fn to_config(&self) -> SupportedConfig {
133        if let Some(ep) = self.openai_endpoint.as_ref() {
134            let cfg = AzureConfig::new()
135                .with_api_base(&self.openai_url)
136                .with_api_key(self.openai_key.clone().unwrap_or_default())
137                .with_deployment_id(ep);
138            SupportedConfig::Azure(cfg)
139        } else {
140            let cfg = OpenAIConfig::new()
141                .with_api_base(&self.openai_url)
142                .with_api_key(self.openai_key.clone().unwrap_or_default());
143            SupportedConfig::OpenAI(cfg)
144        }
145    }
146
147    pub fn to_llm(&self) -> LLM {
148        let billing = RwLock::new(ModelBilling::new(self.biling_cap));
149
150        let debug_path = if let Some(dbg) = self.llm_debug.as_ref() {
151            let pid = std::process::id();
152
153            let mut cnt = 0u64;
154            let debug_path;
155            loop {
156                let test_path = dbg.join(format!("{}-{}", pid, cnt));
157                if !test_path.exists() {
158                    std::fs::create_dir_all(&test_path).expect("Fail to create llm debug path?");
159                    debug_path = Some(test_path);
160                    debug!("The path to save LLM interactions is {:?}", &debug_path);
161                    break;
162                } else {
163                    cnt += 1;
164                }
165            }
166            debug_path
167        } else {
168            None
169        };
170
171        LLM {
172            llm: Arc::new(LLMInner {
173                client: LLMClient::new(self.to_config()),
174                model: self.model.clone(),
175                billing,
176                llm_debug: debug_path,
177                llm_debug_index: AtomicU64::new(0),
178                default_settings: self.llm_settings.clone(),
179            }),
180        }
181    }
182}
183
/// Backend configuration produced by `OpenAISetup::to_config`: `Azure` when an endpoint
/// (deployment id) is configured, `OpenAI` otherwise.
#[derive(Debug, Clone)]
pub enum SupportedConfig {
    Azure(AzureConfig),
    OpenAI(OpenAIConfig),
}
189
/// `async_openai` client for one of the supported backends. The two variants exist because
/// `Client` is generic over its config type.
#[derive(Debug, Clone)]
pub enum LLMClient {
    Azure(Client<AzureConfig>),
    OpenAI(Client<OpenAIConfig>),
}
195
196impl LLMClient {
197    pub fn new(config: SupportedConfig) -> Self {
198        match config {
199            SupportedConfig::Azure(cfg) => Self::Azure(Client::with_config(cfg)),
200            SupportedConfig::OpenAI(cfg) => Self::OpenAI(Client::with_config(cfg)),
201        }
202    }
203
204    pub async fn create_chat(
205        &self,
206        req: CreateChatCompletionRequest,
207    ) -> Result<CreateChatCompletionResponse, OpenAIError> {
208        match self {
209            Self::Azure(cl) => cl.chat().create(req).await,
210            Self::OpenAI(cl) => cl.chat().create(req).await,
211        }
212    }
213}
214
/// Running cost tally checked against a spending cap.
///
/// `current` accumulates token counts multiplied by `OpenAIModel::pricing()` prices divided
/// by 1e6 (i.e. prices quoted per million tokens); the currency is presumably USD — TODO confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelBilling {
    pub current: f64,
    pub cap: f64,
}
220
221impl Display for ModelBilling {
222    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
223        f.write_fmt(format_args!("Billing({}/{})", self.current, self.cap))
224    }
225}
226
227impl ModelBilling {
228    pub fn new(cap: f64) -> Self {
229        Self { current: 0.0, cap }
230    }
231
232    pub fn in_cap(&self) -> bool {
233        self.current <= self.cap
234    }
235
236    pub fn input_tokens(&mut self, model: &OpenAIModel, count: u64) -> Result<()> {
237        let pricing = model.pricing();
238
239        self.current += (pricing.input_tokens * (count as f64)) / 1e6;
240
241        if self.in_cap() {
242            Ok(())
243        } else {
244            Err(eyre!("cap {} reached, current {}", self.cap, self.current))
245        }
246    }
247
248    pub fn output_tokens(&mut self, model: &OpenAIModel, count: u64) -> Result<()> {
249        let pricing = model.pricing();
250
251        self.current += pricing.output_tokens * (count as f64) / 1e6;
252
253        if self.in_cap() {
254            Ok(())
255        } else {
256            Err(eyre!("cap {} reached, current {}", self.cap, self.current))
257        }
258    }
259}
260
/// Cheaply clonable handle to shared [`LLMInner`] state; clones share the same client,
/// billing tally and debug counter through the `Arc`.
#[derive(Debug, Clone)]
pub struct LLM {
    pub llm: Arc<LLMInner>,
}
265
266impl Deref for LLM {
267    type Target = LLMInner;
268
269    fn deref(&self) -> &Self::Target {
270        &self.llm
271    }
272}
273
/// Shared state behind an [`LLM`] handle.
#[derive(Debug)]
pub struct LLMInner {
    /// Backend client used for every completion.
    pub client: LLMClient,
    /// Model name sent with each request; also selects the pricing used for billing.
    pub model: OpenAIModel,
    /// Cost tally, updated after every completion that reports usage.
    pub billing: RwLock<ModelBilling>,
    /// Directory for request/response dumps; `None` disables dumping.
    pub llm_debug: Option<PathBuf>,
    /// Counter used to give each dump file a distinct, zero-padded index.
    pub llm_debug_index: AtomicU64,
    /// Settings used when a caller passes `None`.
    pub default_settings: LLMSettings,
}
283
284pub fn completion_to_role(msg: &ChatCompletionRequestMessage) -> &'static str {
285    match msg {
286        ChatCompletionRequestMessage::Assistant(_) => "ASSISTANT",
287        ChatCompletionRequestMessage::Developer(_) => "DEVELOPER",
288        ChatCompletionRequestMessage::Function(_) => "FUNCTION",
289        ChatCompletionRequestMessage::System(_) => "SYSTEM",
290        ChatCompletionRequestMessage::Tool(_) => "TOOL",
291        ChatCompletionRequestMessage::User(_) => "USER",
292    }
293}
294
295pub fn response_to_string(resp: &ChatCompletionResponseMessage) -> String {
296    let mut s = String::new();
297    if let Some(content) = resp.content.as_ref() {
298        s += content;
299        s += "\n";
300    }
301
302    if let Some(tools) = resp.tool_calls.as_ref() {
303        s += &tools
304            .iter()
305            .map(|t| {
306                format!(
307                    "<toolcall name=\"{}\">\n{}\n</toolcall>",
308                    &t.function.name, &t.function.arguments
309                )
310            })
311            .join("\n");
312    }
313
314    if let Some(refusal) = &resp.refusal {
315        s += refusal;
316        s += "\n";
317    }
318
319    let role = resp.role.to_string().to_uppercase();
320
321    format!("<{}>\n{}\n</{}>\n", &role, s, &role)
322}
323
324pub fn completion_to_string(msg: &ChatCompletionRequestMessage) -> String {
325    const CONT: &str = "<cont/>\n";
326    const NONE: &str = "<none/>\n";
327    let role = completion_to_role(msg);
328    let content = match msg {
329        ChatCompletionRequestMessage::Assistant(ass) => {
330            let msg = ass
331                .content
332                .as_ref()
333                .map(|ass| match ass {
334                    ChatCompletionRequestAssistantMessageContent::Text(s) => s.clone(),
335                    ChatCompletionRequestAssistantMessageContent::Array(arr) => arr
336                        .iter()
337                        .map(|v| match v {
338                            ChatCompletionRequestAssistantMessageContentPart::Text(s) => {
339                                s.text.clone()
340                            }
341                            ChatCompletionRequestAssistantMessageContentPart::Refusal(rf) => {
342                                rf.refusal.clone()
343                            }
344                        })
345                        .join(CONT),
346                })
347                .unwrap_or(NONE.to_string());
348            let tool_calls = ass
349                .tool_calls
350                .iter()
351                .flatten()
352                .map(|t| {
353                    format!(
354                        "<toolcall name=\"{}\">{}</toolcall>",
355                        &t.function.name, &t.function.arguments
356                    )
357                })
358                .join("\n");
359            format!("{}\n{}", msg, tool_calls)
360        }
361        ChatCompletionRequestMessage::Developer(dev) => match &dev.content {
362            ChatCompletionRequestDeveloperMessageContent::Text(t) => t.clone(),
363            ChatCompletionRequestDeveloperMessageContent::Array(arr) => {
364                arr.iter().map(|v| v.text.clone()).join(CONT)
365            }
366        },
367        ChatCompletionRequestMessage::Function(f) => f.content.clone().unwrap_or(NONE.to_string()),
368        ChatCompletionRequestMessage::System(sys) => match &sys.content {
369            ChatCompletionRequestSystemMessageContent::Text(t) => t.clone(),
370            ChatCompletionRequestSystemMessageContent::Array(arr) => arr
371                .iter()
372                .map(|v| match v {
373                    ChatCompletionRequestSystemMessageContentPart::Text(t) => t.text.clone(),
374                })
375                .join(CONT),
376        },
377        ChatCompletionRequestMessage::Tool(tool) => match &tool.content {
378            ChatCompletionRequestToolMessageContent::Text(t) => t.clone(),
379            ChatCompletionRequestToolMessageContent::Array(arr) => arr
380                .iter()
381                .map(|v| match v {
382                    ChatCompletionRequestToolMessageContentPart::Text(t) => t.text.clone(),
383                })
384                .join(CONT),
385        },
386        ChatCompletionRequestMessage::User(usr) => match &usr.content {
387            ChatCompletionRequestUserMessageContent::Text(t) => t.clone(),
388            ChatCompletionRequestUserMessageContent::Array(arr) => arr
389                .iter()
390                .map(|v| match v {
391                    ChatCompletionRequestUserMessageContentPart::Text(t) => t.text.clone(),
392                    ChatCompletionRequestUserMessageContentPart::ImageUrl(img) => {
393                        format!("<img url=\"{}\"/>", &img.image_url.url)
394                    }
395                    ChatCompletionRequestUserMessageContentPart::InputAudio(audio) => {
396                        format!("<audio>{}</audio>", audio.input_audio.data)
397                    }
398                })
399                .join(CONT),
400        },
401    };
402
403    format!("<{}>\n{}\n</{}>\n", role, content, role)
404}
405
406impl LLMInner {
407    async fn rewrite_json<T: Serialize + Debug>(fpath: &Path, t: &T) -> Result<(), PromptError> {
408        let mut json_fp = fpath.to_path_buf();
409        json_fp.set_file_name(format!(
410            "{}.json",
411            json_fp
412                .file_stem()
413                .ok_or_eyre(eyre!("no filename"))?
414                .to_str()
415                .ok_or_eyre(eyre!("non-utf fname"))?
416        ));
417
418        let mut fp = tokio::fs::OpenOptions::new()
419            .create(true)
420            .append(true)
421            .write(true)
422            .open(&json_fp)
423            .await?;
424        let s = match serde_json::to_string(&t) {
425            Ok(s) => s,
426            Err(_) => format!("{:?}", &t),
427        };
428        fp.write_all(s.as_bytes()).await?;
429        fp.write_all(b"\n").await?;
430        fp.flush().await?;
431
432        Ok(())
433    }
434
435    async fn save_llm_user(
436        fpath: &PathBuf,
437        user_msg: &CreateChatCompletionRequest,
438    ) -> Result<(), PromptError> {
439        let mut fp = tokio::fs::OpenOptions::new()
440            .create(true)
441            .truncate(true)
442            .write(true)
443            .open(&fpath)
444            .await?;
445        fp.write_all(b"<Request>\n").await?;
446        for it in user_msg.messages.iter() {
447            let msg = completion_to_string(it);
448            fp.write_all(msg.as_bytes()).await?;
449        }
450
451        let mut tools = vec![];
452        for tool in user_msg
453            .tools
454            .as_ref()
455            .map(|t: &Vec<async_openai::types::ChatCompletionTool>| t.iter())
456            .into_iter()
457            .flatten()
458        {
459            tools.push(format!(
460                "<tool name=\"{}\", description=\"{}\", strict={}>\n{}\n</tool>",
461                &tool.function.name,
462                &tool.function.description.clone().unwrap_or_default(),
463                tool.function.strict.unwrap_or_default(),
464                tool.function
465                    .parameters
466                    .as_ref()
467                    .map(serde_json::to_string_pretty)
468                    .transpose()?
469                    .unwrap_or_default()
470            ));
471        }
472        fp.write_all(tools.join("\n").as_bytes()).await?;
473        fp.write_all(b"\n</Request>\n").await?;
474        fp.flush().await?;
475
476        Self::rewrite_json(fpath, user_msg).await?;
477
478        Ok(())
479    }
480
481    async fn save_llm_resp(fpath: &PathBuf, resp: &CreateChatCompletionResponse) -> Result<()> {
482        let mut fp = tokio::fs::OpenOptions::new()
483            .create(false)
484            .append(true)
485            .write(true)
486            .open(&fpath)
487            .await?;
488        fp.write_all(b"<Response>\n").await?;
489        for it in &resp.choices {
490            let msg = response_to_string(&it.message);
491            fp.write_all(msg.as_bytes()).await?;
492        }
493        fp.write_all(b"\n</Response>\n").await?;
494        fp.flush().await?;
495
496        Self::rewrite_json(fpath, resp).await?;
497
498        Ok(())
499    }
500
501    fn on_llm_debug(&self, prefix: &str) -> Option<PathBuf> {
502        if let Some(output_folder) = self.llm_debug.as_ref() {
503            let idx = self.llm_debug_index.fetch_add(1, Ordering::SeqCst);
504            let fpath = output_folder.join(format!("{}-{:0>12}.xml", prefix, idx));
505            Some(fpath)
506        } else {
507            None
508        }
509    }
510
511    // we use t/s to estimate a timeout to avoid infinite repeating
512    pub async fn prompt_once_with_retry(
513        &self,
514        sys_msg: &str,
515        user_msg: &str,
516        prefix: Option<&str>,
517        settings: Option<LLMSettings>,
518    ) -> Result<CreateChatCompletionResponse, PromptError> {
519        let settings = settings.unwrap_or_else(|| self.default_settings.clone());
520        let sys = ChatCompletionRequestSystemMessageArgs::default()
521            .content(sys_msg)
522            .build()?;
523
524        let user = ChatCompletionRequestUserMessageArgs::default()
525            .content(user_msg)
526            .build()?;
527        let req = CreateChatCompletionRequestArgs::default()
528            .messages(vec![sys.into(), user.into()])
529            .model(self.model.to_string())
530            .temperature(settings.llm_temperature)
531            .presence_penalty(settings.llm_presence_penalty)
532            .max_completion_tokens(settings.llm_max_completion_tokens)
533            .tool_choice(settings.llm_tool_choice)
534            .build()?;
535
536        let timeout = if settings.llm_prompt_timeout == 0 {
537            Duration::MAX
538        } else {
539            Duration::from_secs(settings.llm_prompt_timeout)
540        };
541
542        self.complete_once_with_retry(&req, prefix, Some(timeout), Some(settings.llm_retry))
543            .await
544    }
545
546    pub async fn complete_once_with_retry(
547        &self,
548        req: &CreateChatCompletionRequest,
549        prefix: Option<&str>,
550        timeout: Option<Duration>,
551        retry: Option<u64>,
552    ) -> Result<CreateChatCompletionResponse, PromptError> {
553        let timeout = if let Some(timeout) = timeout {
554            timeout
555        } else {
556            Duration::MAX
557        };
558
559        let retry = if let Some(retry) = retry {
560            retry
561        } else {
562            u64::MAX
563        };
564
565        let mut last = None;
566        for idx in 0..retry {
567            match tokio::time::timeout(timeout, self.complete(req.clone(), prefix)).await {
568                Ok(r) => {
569                    last = Some(r);
570                }
571                Err(_) => {
572                    warn!("Timeout with {} retry, timeout = {:?}", idx, timeout);
573                    continue;
574                }
575            };
576
577            match last {
578                Some(Ok(r)) => return Ok(r),
579                Some(Err(ref e)) => {
580                    warn!(
581                        "Having an error {} during {} retry (timeout is {:?})",
582                        e, idx, timeout
583                    );
584                }
585                _ => {}
586            }
587        }
588
589        last.ok_or_eyre(eyre!("retry is zero?!"))
590            .map_err(PromptError::Other)?
591    }
592
593    pub async fn complete(
594        &self,
595        req: CreateChatCompletionRequest,
596        prefix: Option<&str>,
597    ) -> Result<CreateChatCompletionResponse, PromptError> {
598        let prefix = if let Some(prefix) = prefix {
599            prefix.to_string()
600        } else {
601            "llm".to_string()
602        };
603        let debug_fp = self.on_llm_debug(&prefix);
604
605        if let Some(debug_fp) = debug_fp.as_ref() {
606            if let Err(e) = Self::save_llm_user(debug_fp, &req).await {
607                warn!("Fail to save user due to {}", e);
608            }
609        }
610
611        trace!("Sending completion request: {:?}", &req);
612        let resp = self.client.create_chat(req).await?;
613
614        if let Some(debug_fp) = debug_fp.as_ref() {
615            if let Err(e) = Self::save_llm_resp(debug_fp, &resp).await {
616                warn!("Fail to save resp due to {}", e);
617            }
618        }
619
620        if let Some(usage) = &resp.usage {
621            self.billing
622                .write()
623                .await
624                .input_tokens(&self.model, usage.prompt_tokens as u64)
625                .map_err(PromptError::Other)?;
626            self.billing
627                .write()
628                .await
629                .output_tokens(&self.model, usage.completion_tokens as u64)
630                .map_err(PromptError::Other)?;
631        } else {
632            warn!("No usage?!")
633        }
634
635        info!("Model Billing: {}", &self.billing.read().await);
636        Ok(resp)
637    }
638
639    pub async fn prompt_once(
640        &self,
641        sys_msg: &str,
642        user_msg: &str,
643        prefix: Option<&str>,
644        settings: Option<LLMSettings>,
645    ) -> Result<CreateChatCompletionResponse, PromptError> {
646        let settings = settings.unwrap_or_else(|| self.default_settings.clone());
647        let sys = ChatCompletionRequestSystemMessageArgs::default()
648            .content(sys_msg)
649            .build()?;
650
651        let user = ChatCompletionRequestUserMessageArgs::default()
652            .content(user_msg)
653            .build()?;
654        let req = CreateChatCompletionRequestArgs::default()
655            .messages(vec![sys.into(), user.into()])
656            .model(self.model.to_string())
657            .temperature(settings.llm_temperature)
658            .presence_penalty(settings.llm_presence_penalty)
659            .max_completion_tokens(settings.llm_max_completion_tokens)
660            .build()?;
661        self.complete(req, prefix).await
662    }
663}