use std::{
    fmt::{Debug, Display},
    ops::{Deref, DerefMut},
    path::{Path, PathBuf},
    str::FromStr,
    sync::{
        Arc,
        atomic::{AtomicU64, Ordering},
    },
    time::Duration,
};

use async_openai::{
    Client,
    config::{AzureConfig, OpenAIConfig},
    error::OpenAIError,
    types::{
        ChatCompletionRequestAssistantMessageContent,
        ChatCompletionRequestAssistantMessageContentPart,
        ChatCompletionRequestDeveloperMessageContent, ChatCompletionRequestMessage,
        ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestSystemMessageContent,
        ChatCompletionRequestSystemMessageContentPart, ChatCompletionRequestToolMessageContent,
        ChatCompletionRequestToolMessageContentPart, ChatCompletionRequestUserMessageArgs,
        ChatCompletionRequestUserMessageContent, ChatCompletionRequestUserMessageContentPart,
        ChatCompletionResponseMessage, ChatCompletionToolChoiceOption, CreateChatCompletionRequest,
        CreateChatCompletionRequestArgs, CreateChatCompletionResponse,
    },
};
use clap::Args;
use color_eyre::{
    Result,
    eyre::{OptionExt, eyre},
};
use itertools::Itertools;
use log::{debug, info, trace, warn};
use serde::{Deserialize, Serialize};
use tokio::{io::AsyncWriteExt, sync::RwLock};

use crate::{OpenAIModel, error::PromptError};

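/// Newtype around [`ChatCompletionToolChoiceOption`] so the tool-choice mode can be
/// parsed from a CLI/env string ("auto", "required", "none", or a named tool).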
#[derive(Debug, Clone)]
pub struct LLMToolChoice(pub ChatCompletionToolChoiceOption);

impl FromStr for LLMToolChoice {
    type Err = PromptError;
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        Ok(match s {
            "auto" => Self(ChatCompletionToolChoiceOption::Auto),
            "required" => Self(ChatCompletionToolChoiceOption::Required),
            "none" => Self(ChatCompletionToolChoiceOption::None),
            _ => Self(ChatCompletionToolChoiceOption::Named(s.into())),
        })
    }
}

impl Deref for LLMToolChoice {
    type Target = ChatCompletionToolChoiceOption;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl DerefMut for LLMToolChoice {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

impl From<ChatCompletionToolChoiceOption> for LLMToolChoice {
    fn from(value: ChatCompletionToolChoiceOption) -> Self {
        Self(value)
    }
}

impl From<LLMToolChoice> for ChatCompletionToolChoiceOption {
    fn from(value: LLMToolChoice) -> Self {
        value.0
    }
}

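/// Per-request LLM tuning knobs, configurable from the CLI or environment.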
#[derive(Args, Clone, Debug)]
pub struct LLMSettings {
    #[arg(long, env = "LLM_TEMPERATURE", default_value_t = 0.8)]
    pub llm_temperature: f32,

    #[arg(long, env = "LLM_PRESENCE_PENALTY", default_value_t = 0.0)]
    pub llm_presence_penalty: f32,

    #[arg(long, env = "LLM_PROMPT_TIMEOUT", default_value_t = 120)]
    pub llm_prompt_timeout: u64,

    #[arg(long, env = "LLM_RETRY", default_value_t = 5)]
    pub llm_retry: u64,

    #[arg(long, env = "LLM_MAX_COMPLETION_TOKENS", default_value_t = 16384)]
    pub llm_max_completion_tokens: u32,

    #[arg(long, env = "LLM_TOOL_CHOICE", default_value = "auto")]
    pub llm_tool_choice: LLMToolChoice,
}

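/// CLI/env configuration for the OpenAI (or Azure OpenAI) backend, including the
/// billing cap, model selection, and optional LLM debug directory.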
#[derive(Args, Clone, Debug)]
pub struct OpenAISetup {
    #[arg(
        long,
        env = "OPENAI_API_URL",
        default_value = "https://api.openai.com/v1"
    )]
    pub openai_url: String,

    #[arg(long, env = "OPENAI_API_KEY")]
    pub openai_key: Option<String>,

    #[arg(long, env = "OPENAI_API_ENDPOINT")]
    pub openai_endpoint: Option<String>,

    #[arg(long, default_value_t = 10.0, env = "OPENAI_BILLING_CAP")]
    pub billing_cap: f64,

    #[arg(long, env = "OPENAI_API_MODEL", default_value = "o1")]
    pub model: OpenAIModel,

    #[arg(long, env = "LLM_DEBUG")]
    pub llm_debug: Option<PathBuf>,

    #[clap(flatten)]
    pub llm_settings: LLMSettings,
}

impl OpenAISetup {
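    /// Builds the client configuration: Azure when an API endpoint (deployment) is
    /// given, plain OpenAI otherwise.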
    pub fn to_config(&self) -> SupportedConfig {
        if let Some(ep) = self.openai_endpoint.as_ref() {
            let cfg = AzureConfig::new()
                .with_api_base(&self.openai_url)
                .with_api_key(self.openai_key.clone().unwrap_or_default())
                .with_deployment_id(ep);
            SupportedConfig::Azure(cfg)
        } else {
            let cfg = OpenAIConfig::new()
                .with_api_base(&self.openai_url)
                .with_api_key(self.openai_key.clone().unwrap_or_default());
            SupportedConfig::OpenAI(cfg)
        }
    }

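    /// Creates the shared [`LLM`] handle, initializing billing and, if requested,
    /// a per-process debug directory for recorded interactions.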
    pub fn to_llm(&self) -> LLM {
        let billing = RwLock::new(ModelBilling::new(self.billing_cap));

        let debug_path = if let Some(dbg) = self.llm_debug.as_ref() {
            let pid = std::process::id();

            let mut cnt = 0u64;
            let debug_path;
            loop {
                let test_path = dbg.join(format!("{}-{}", pid, cnt));
                if !test_path.exists() {
                    std::fs::create_dir_all(&test_path)
                        .expect("failed to create the LLM debug path");
                    debug_path = Some(test_path);
                    debug!("The path to save LLM interactions is {:?}", &debug_path);
                    break;
                } else {
                    cnt += 1;
                }
            }
            debug_path
        } else {
            None
        };

        LLM {
            llm: Arc::new(LLMInner {
                client: LLMClient::new(self.to_config()),
                model: self.model.clone(),
                billing,
                llm_debug: debug_path,
                llm_debug_index: AtomicU64::new(0),
                default_settings: self.llm_settings.clone(),
            }),
        }
    }
}

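/// Backend configuration, either Azure OpenAI or the standard OpenAI API.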
#[derive(Debug, Clone)]
pub enum SupportedConfig {
    Azure(AzureConfig),
    OpenAI(OpenAIConfig),
}

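/// A chat-completion client for whichever backend was configured.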
#[derive(Debug, Clone)]
pub enum LLMClient {
    Azure(Client<AzureConfig>),
    OpenAI(Client<OpenAIConfig>),
}

impl LLMClient {
    pub fn new(config: SupportedConfig) -> Self {
        match config {
            SupportedConfig::Azure(cfg) => Self::Azure(Client::with_config(cfg)),
            SupportedConfig::OpenAI(cfg) => Self::OpenAI(Client::with_config(cfg)),
        }
    }

    pub async fn create_chat(
        &self,
        req: CreateChatCompletionRequest,
    ) -> Result<CreateChatCompletionResponse, OpenAIError> {
        match self {
            Self::Azure(cl) => cl.chat().create(req).await,
            Self::OpenAI(cl) => cl.chat().create(req).await,
        }
    }
}

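/// Tracks accumulated spend for a model against a hard cap (in the same currency
/// unit as the model pricing).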
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelBilling {
    pub current: f64,
    pub cap: f64,
}

impl Display for ModelBilling {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("Billing({}/{})", self.current, self.cap))
    }
}

impl ModelBilling {
    pub fn new(cap: f64) -> Self {
        Self { current: 0.0, cap }
    }

    pub fn in_cap(&self) -> bool {
        self.current <= self.cap
    }

    pub fn input_tokens(&mut self, model: &OpenAIModel, count: u64) -> Result<()> {
        let pricing = model.pricing();

        self.current += (pricing.input_tokens * (count as f64)) / 1e6;

        if self.in_cap() {
            Ok(())
        } else {
            Err(eyre!("cap {} reached, current {}", self.cap, self.current))
        }
    }

    pub fn output_tokens(&mut self, model: &OpenAIModel, count: u64) -> Result<()> {
        let pricing = model.pricing();

        self.current += pricing.output_tokens * (count as f64) / 1e6;

        if self.in_cap() {
            Ok(())
        } else {
            Err(eyre!("cap {} reached, current {}", self.cap, self.current))
        }
    }
}

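/// Cheaply cloneable handle to the shared [`LLMInner`] state.
///
/// A minimal usage sketch (illustrative only; it assumes `OpenAISetup` is obtained
/// from clap in the calling binary and that the hypothetical `Cli` struct below is
/// defined there):
///
/// ```ignore
/// use clap::Parser;
///
/// #[derive(Parser)]
/// struct Cli {
///     #[clap(flatten)]
///     openai: OpenAISetup,
/// }
///
/// async fn run() -> Result<(), PromptError> {
///     let cli = Cli::parse();
///     let llm = cli.openai.to_llm();
///     // One system+user round trip with the default settings and retry policy.
///     let resp = llm
///         .prompt_once_with_retry("You are a helpful assistant.", "Say hi.", Some("demo"), None)
///         .await?;
///     println!("{:?}", resp.choices.first().map(|c| &c.message.content));
///     Ok(())
/// }
/// ```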
#[derive(Debug, Clone)]
pub struct LLM {
    pub llm: Arc<LLMInner>,
}

impl Deref for LLM {
    type Target = LLMInner;

    fn deref(&self) -> &Self::Target {
        &self.llm
    }
}

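/// Shared state behind [`LLM`]: the backend client, model, billing, and optional
/// debug-dump settings.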
#[derive(Debug)]
pub struct LLMInner {
    pub client: LLMClient,
    pub model: OpenAIModel,
    pub billing: RwLock<ModelBilling>,
    pub llm_debug: Option<PathBuf>,
    pub llm_debug_index: AtomicU64,
    pub default_settings: LLMSettings,
}

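/// Maps a request message to the upper-case role tag used in the debug dumps.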
pub fn completion_to_role(msg: &ChatCompletionRequestMessage) -> &'static str {
    match msg {
        ChatCompletionRequestMessage::Assistant(_) => "ASSISTANT",
        ChatCompletionRequestMessage::Developer(_) => "DEVELOPER",
        ChatCompletionRequestMessage::Function(_) => "FUNCTION",
        ChatCompletionRequestMessage::System(_) => "SYSTEM",
        ChatCompletionRequestMessage::Tool(_) => "TOOL",
        ChatCompletionRequestMessage::User(_) => "USER",
    }
}

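/// Renders a response message (content, tool calls, refusal) as the XML-like text
/// used in the debug dumps.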
pub fn response_to_string(resp: &ChatCompletionResponseMessage) -> String {
    let mut s = String::new();
    if let Some(content) = resp.content.as_ref() {
        s += content;
        s += "\n";
    }

    if let Some(tools) = resp.tool_calls.as_ref() {
        s += &tools
            .iter()
            .map(|t| {
                format!(
                    "<toolcall name=\"{}\">\n{}\n</toolcall>",
                    &t.function.name, &t.function.arguments
                )
            })
            .join("\n");
    }

    if let Some(refusal) = &resp.refusal {
        s += refusal;
        s += "\n";
    }

    let role = resp.role.to_string().to_uppercase();

    format!("<{}>\n{}\n</{}>\n", &role, s, &role)
}

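/// Renders a request message as the XML-like text used in the debug dumps,
/// flattening multi-part content with `<cont/>` separators.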
pub fn completion_to_string(msg: &ChatCompletionRequestMessage) -> String {
    const CONT: &str = "<cont/>\n";
    const NONE: &str = "<none/>\n";
    let role = completion_to_role(msg);
    let content = match msg {
        ChatCompletionRequestMessage::Assistant(ass) => {
            let text = ass
                .content
                .as_ref()
                .map(|ass| match ass {
                    ChatCompletionRequestAssistantMessageContent::Text(s) => s.clone(),
                    ChatCompletionRequestAssistantMessageContent::Array(arr) => arr
                        .iter()
                        .map(|v| match v {
                            ChatCompletionRequestAssistantMessageContentPart::Text(s) => {
                                s.text.clone()
                            }
                            ChatCompletionRequestAssistantMessageContentPart::Refusal(rf) => {
                                rf.refusal.clone()
                            }
                        })
                        .join(CONT),
                })
                .unwrap_or(NONE.to_string());
            let tool_calls = ass
                .tool_calls
                .iter()
                .flatten()
                .map(|t| {
                    format!(
                        "<toolcall name=\"{}\">{}</toolcall>",
                        &t.function.name, &t.function.arguments
                    )
                })
                .join("\n");
            format!("{}\n{}", text, tool_calls)
        }
        ChatCompletionRequestMessage::Developer(dev) => match &dev.content {
            ChatCompletionRequestDeveloperMessageContent::Text(t) => t.clone(),
            ChatCompletionRequestDeveloperMessageContent::Array(arr) => {
                arr.iter().map(|v| v.text.clone()).join(CONT)
            }
        },
        ChatCompletionRequestMessage::Function(f) => f.content.clone().unwrap_or(NONE.to_string()),
        ChatCompletionRequestMessage::System(sys) => match &sys.content {
            ChatCompletionRequestSystemMessageContent::Text(t) => t.clone(),
            ChatCompletionRequestSystemMessageContent::Array(arr) => arr
                .iter()
                .map(|v| match v {
                    ChatCompletionRequestSystemMessageContentPart::Text(t) => t.text.clone(),
                })
                .join(CONT),
        },
        ChatCompletionRequestMessage::Tool(tool) => match &tool.content {
            ChatCompletionRequestToolMessageContent::Text(t) => t.clone(),
            ChatCompletionRequestToolMessageContent::Array(arr) => arr
                .iter()
                .map(|v| match v {
                    ChatCompletionRequestToolMessageContentPart::Text(t) => t.text.clone(),
                })
                .join(CONT),
        },
        ChatCompletionRequestMessage::User(usr) => match &usr.content {
            ChatCompletionRequestUserMessageContent::Text(t) => t.clone(),
            ChatCompletionRequestUserMessageContent::Array(arr) => arr
                .iter()
                .map(|v| match v {
                    ChatCompletionRequestUserMessageContentPart::Text(t) => t.text.clone(),
                    ChatCompletionRequestUserMessageContentPart::ImageUrl(img) => {
                        format!("<img url=\"{}\"/>", &img.image_url.url)
                    }
                    ChatCompletionRequestUserMessageContentPart::InputAudio(audio) => {
                        format!("<audio>{}</audio>", audio.input_audio.data)
                    }
                })
                .join(CONT),
        },
    };

    format!("<{}>\n{}\n</{}>\n", role, content, role)
}

impl LLMInner {
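    /// Appends `t`, serialized as JSON (or its `Debug` form if serialization fails),
    /// to a sibling file that shares `fpath`'s stem but uses a `.json` extension.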
    async fn rewrite_json<T: Serialize + Debug>(fpath: &Path, t: &T) -> Result<(), PromptError> {
        let mut json_fp = fpath.to_path_buf();
        json_fp.set_file_name(format!(
            "{}.json",
            json_fp
                .file_stem()
                .ok_or_eyre(eyre!("no file name"))?
                .to_str()
                .ok_or_eyre(eyre!("non-UTF-8 file name"))?
        ));

        let mut fp = tokio::fs::OpenOptions::new()
            .create(true)
            .append(true)
            .write(true)
            .open(&json_fp)
            .await?;
        let s = match serde_json::to_string(&t) {
            Ok(s) => s,
            Err(_) => format!("{:?}", &t),
        };
        fp.write_all(s.as_bytes()).await?;
        fp.write_all(b"\n").await?;
        fp.flush().await?;

        Ok(())
    }

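    /// Dumps the outgoing request (messages and tool definitions) to the debug file,
    /// plus the raw JSON via [`Self::rewrite_json`].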
    async fn save_llm_user(
        fpath: &PathBuf,
        user_msg: &CreateChatCompletionRequest,
    ) -> Result<(), PromptError> {
        let mut fp = tokio::fs::OpenOptions::new()
            .create(true)
            .truncate(true)
            .write(true)
            .open(&fpath)
            .await?;
        fp.write_all(b"<Request>\n").await?;
        for it in user_msg.messages.iter() {
            let msg = completion_to_string(it);
            fp.write_all(msg.as_bytes()).await?;
        }

        let mut tools = vec![];
        for tool in user_msg
            .tools
            .as_ref()
            .map(|t: &Vec<async_openai::types::ChatCompletionTool>| t.iter())
            .into_iter()
            .flatten()
        {
            tools.push(format!(
                "<tool name=\"{}\", description=\"{}\", strict={}>\n{}\n</tool>",
                &tool.function.name,
                &tool.function.description.clone().unwrap_or_default(),
                tool.function.strict.unwrap_or_default(),
                tool.function
                    .parameters
                    .as_ref()
                    .map(serde_json::to_string_pretty)
                    .transpose()?
                    .unwrap_or_default()
            ));
        }
        fp.write_all(tools.join("\n").as_bytes()).await?;
        fp.write_all(b"\n</Request>\n").await?;
        fp.flush().await?;

        Self::rewrite_json(fpath, user_msg).await?;

        Ok(())
    }

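    /// Appends the model response to the same debug file, plus the raw JSON via
    /// [`Self::rewrite_json`].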
    async fn save_llm_resp(fpath: &PathBuf, resp: &CreateChatCompletionResponse) -> Result<()> {
        let mut fp = tokio::fs::OpenOptions::new()
            .create(false)
            .append(true)
            .write(true)
            .open(&fpath)
            .await?;
        fp.write_all(b"<Response>\n").await?;
        for it in &resp.choices {
            let msg = response_to_string(&it.message);
            fp.write_all(msg.as_bytes()).await?;
        }
        fp.write_all(b"\n</Response>\n").await?;
        fp.flush().await?;

        Self::rewrite_json(fpath, resp).await?;

        Ok(())
    }

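    /// Returns the next debug file path (`<prefix>-<counter>.xml`) when LLM debugging
    /// is enabled.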
    fn on_llm_debug(&self, prefix: &str) -> Option<PathBuf> {
        if let Some(output_folder) = self.llm_debug.as_ref() {
            let idx = self.llm_debug_index.fetch_add(1, Ordering::SeqCst);
            let fpath = output_folder.join(format!("{}-{:0>12}.xml", prefix, idx));
            Some(fpath)
        } else {
            None
        }
    }

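    /// Convenience wrapper: builds a system+user chat request from the given settings
    /// (or the defaults) and runs it through [`Self::complete_once_with_retry`].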
    pub async fn prompt_once_with_retry(
        &self,
        sys_msg: &str,
        user_msg: &str,
        prefix: Option<&str>,
        settings: Option<LLMSettings>,
    ) -> Result<CreateChatCompletionResponse, PromptError> {
        let settings = settings.unwrap_or_else(|| self.default_settings.clone());
        let sys = ChatCompletionRequestSystemMessageArgs::default()
            .content(sys_msg)
            .build()?;

        let user = ChatCompletionRequestUserMessageArgs::default()
            .content(user_msg)
            .build()?;
        let req = CreateChatCompletionRequestArgs::default()
            .messages(vec![sys.into(), user.into()])
            .model(self.model.to_string())
            .temperature(settings.llm_temperature)
            .presence_penalty(settings.llm_presence_penalty)
            .max_completion_tokens(settings.llm_max_completion_tokens)
            .tool_choice(settings.llm_tool_choice)
            .build()?;

        let timeout = if settings.llm_prompt_timeout == 0 {
            Duration::MAX
        } else {
            Duration::from_secs(settings.llm_prompt_timeout)
        };

        self.complete_once_with_retry(&req, prefix, Some(timeout), Some(settings.llm_retry))
            .await
    }

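    /// Runs [`Self::complete`] up to `retry` times, treating each attempt that exceeds
    /// `timeout` as a failure; `None` means no timeout and effectively unlimited retries.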
    pub async fn complete_once_with_retry(
        &self,
        req: &CreateChatCompletionRequest,
        prefix: Option<&str>,
        timeout: Option<Duration>,
        retry: Option<u64>,
    ) -> Result<CreateChatCompletionResponse, PromptError> {
        let timeout = if let Some(timeout) = timeout {
            timeout
        } else {
            Duration::MAX
        };

        let retry = if let Some(retry) = retry {
            retry
        } else {
            u64::MAX
        };

        let mut last = None;
        for idx in 0..retry {
            match tokio::time::timeout(timeout, self.complete(req.clone(), prefix)).await {
                Ok(r) => {
                    last = Some(r);
                }
                Err(_) => {
                    warn!("Timed out on retry {} (timeout = {:?})", idx, timeout);
                    continue;
                }
            };

            match last {
                Some(Ok(r)) => return Ok(r),
                Some(Err(ref e)) => {
                    warn!("Error {} on retry {} (timeout = {:?})", e, idx, timeout);
                }
                _ => {}
            }
        }

        last.ok_or_eyre(eyre!("retry count is zero"))
            .map_err(PromptError::Other)?
    }

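    /// Sends one chat-completion request, recording the request and response when
    /// debugging is enabled and charging token usage against the billing cap.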
    pub async fn complete(
        &self,
        req: CreateChatCompletionRequest,
        prefix: Option<&str>,
    ) -> Result<CreateChatCompletionResponse, PromptError> {
        let prefix = if let Some(prefix) = prefix {
            prefix.to_string()
        } else {
            "llm".to_string()
        };
        let debug_fp = self.on_llm_debug(&prefix);

        if let Some(debug_fp) = debug_fp.as_ref() {
            if let Err(e) = Self::save_llm_user(debug_fp, &req).await {
                warn!("Failed to save the request due to {}", e);
            }
        }

        trace!("Sending completion request: {:?}", &req);
        let resp = self.client.create_chat(req).await?;

        if let Some(debug_fp) = debug_fp.as_ref() {
            if let Err(e) = Self::save_llm_resp(debug_fp, &resp).await {
                warn!("Failed to save the response due to {}", e);
            }
        }

        if let Some(usage) = &resp.usage {
            self.billing
                .write()
                .await
                .input_tokens(&self.model, usage.prompt_tokens as u64)
                .map_err(PromptError::Other)?;
            self.billing
                .write()
                .await
                .output_tokens(&self.model, usage.completion_tokens as u64)
                .map_err(PromptError::Other)?;
        } else {
            warn!("No usage information in the response");
        }

        info!("Model Billing: {}", &self.billing.read().await);
        Ok(resp)
    }

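    /// Single-shot prompt without retry or timeout handling; builds a system+user
    /// request from the given settings (or the defaults) and calls [`Self::complete`].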
    pub async fn prompt_once(
        &self,
        sys_msg: &str,
        user_msg: &str,
        prefix: Option<&str>,
        settings: Option<LLMSettings>,
    ) -> Result<CreateChatCompletionResponse, PromptError> {
        let settings = settings.unwrap_or_else(|| self.default_settings.clone());
        let sys = ChatCompletionRequestSystemMessageArgs::default()
            .content(sys_msg)
            .build()?;

        let user = ChatCompletionRequestUserMessageArgs::default()
            .content(user_msg)
            .build()?;
        let req = CreateChatCompletionRequestArgs::default()
            .messages(vec![sys.into(), user.into()])
            .model(self.model.to_string())
            .temperature(settings.llm_temperature)
            .presence_penalty(settings.llm_presence_penalty)
            .max_completion_tokens(settings.llm_max_completion_tokens)
            .build()?;
        self.complete(req, prefix).await
    }
}