llm_chain_openai/chatgpt/prompt.rs

use async_openai::types::{
    ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestFunctionMessageArgs,
    ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs,
    ChatCompletionRequestToolMessageArgs, ChatCompletionRequestUserMessageArgs,
    ChatCompletionResponseStream, CreateChatCompletionRequest, CreateChatCompletionRequestArgs,
    CreateChatCompletionResponse, Role,
};
use futures::StreamExt;
use llm_chain::prompt::{self, Prompt};
use llm_chain::{
    output::{Output, StreamSegment},
    prompt::{ChatMessage, ChatMessageCollection},
};

use super::error::OpenAIInnerError;

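/// Maps an llm-chain [`prompt::ChatRole`] to the corresponding `async_openai` [`Role`].
/// Roles with no direct OpenAI equivalent (`ChatRole::Other`) fall back to `Role::User`.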
fn convert_role(role: &prompt::ChatRole) -> Role {
    match role {
        prompt::ChatRole::User => Role::User,
        prompt::ChatRole::Assistant => Role::Assistant,
        prompt::ChatRole::System => Role::System,
        prompt::ChatRole::Other(_s) => Role::User,
    }
}

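/// Maps an `async_openai` [`Role`] back to an llm-chain [`prompt::ChatRole`];
/// tool and function roles are preserved as `ChatRole::Other`.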
fn convert_openai_role(role: &Role) -> prompt::ChatRole {
    match role {
        Role::User => prompt::ChatRole::User,
        Role::Assistant => prompt::ChatRole::Assistant,
        Role::System => prompt::ChatRole::System,
        Role::Tool => prompt::ChatRole::Other("Tool".to_string()),
        Role::Function => prompt::ChatRole::Other("Function".to_string()),
    }
}

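/// Converts a single llm-chain chat message into the request message type expected by
/// `async_openai`, choosing the builder that matches the message's role.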
fn format_chat_message(
    message: &prompt::ChatMessage<String>,
) -> Result<ChatCompletionRequestMessage, OpenAIInnerError> {
    let role = convert_role(message.role());
    let content = message.body().to_string();
    let msg = match role {
        Role::Assistant => ChatCompletionRequestMessage::Assistant(
            ChatCompletionRequestAssistantMessageArgs::default()
                .content(content)
                .build()?,
        ),
        Role::System => ChatCompletionRequestMessage::System(
            ChatCompletionRequestSystemMessageArgs::default()
                .content(content)
                .build()?,
        ),
        Role::User => ChatCompletionRequestMessage::User(
            ChatCompletionRequestUserMessageArgs::default()
                .content(content)
                .build()?,
        ),
        Role::Tool => ChatCompletionRequestMessage::Tool(
            ChatCompletionRequestToolMessageArgs::default()
                .content(content)
                .build()?,
        ),
        Role::Function => ChatCompletionRequestMessage::Function(
            ChatCompletionRequestFunctionMessageArgs::default()
                .content(content)
                .build()?,
        ),
    };
    Ok(msg)
}

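/// Converts a collection of llm-chain chat messages into `async_openai` request messages,
/// failing on the first message that cannot be built.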
pub fn format_chat_messages(
    messages: prompt::ChatMessageCollection<String>,
) -> Result<Vec<async_openai::types::ChatCompletionRequestMessage>, OpenAIInnerError> {
    messages.iter().map(format_chat_message).collect()
}

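/// Builds a [`CreateChatCompletionRequest`] for the given model and prompt,
/// marking the request as streaming when `is_streaming` is true.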
pub fn create_chat_completion_request(
    model: String,
    prompt: &Prompt,
    is_streaming: bool,
) -> Result<CreateChatCompletionRequest, OpenAIInnerError> {
    let messages = format_chat_messages(prompt.to_chat())?;
    Ok(CreateChatCompletionRequestArgs::default()
        .model(model)
        .stream(is_streaming)
        .messages(messages)
        .build()?)
}

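/// Converts a non-streaming chat completion response into an llm-chain [`Output`]
/// containing the message of the first choice.
///
/// # Panics
///
/// Panics if the response contains no choices.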
pub fn completion_to_output(resp: CreateChatCompletionResponse) -> Output {
    let msg = resp.choices.first().unwrap().message.clone();
    let mut col = ChatMessageCollection::new();
    col.add_message(ChatMessage::new(
        convert_openai_role(&msg.role),
        msg.content.unwrap_or_default(),
    ));
    Output::new_immediate(col.into())
}

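/// Converts a streaming chat completion response into an llm-chain [`Output`],
/// forwarding role and content deltas as [`StreamSegment`]s.
///
/// # Panics
///
/// Panics if the stream yields an error or a chunk with no choices.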
pub fn stream_to_output(resp: ChatCompletionResponseStream) -> Output {
    let stream = resp.flat_map(|x| {
        let resp = x.unwrap();

        let delta = resp.choices.first().unwrap().delta.clone();

        let mut v = vec![];

        // The first chunk of a stream carries the role; later chunks carry content deltas.
        if let Some(role) = delta.role {
            v.push(StreamSegment::Role(convert_openai_role(&role)));
        }
        if let Some(content) = delta.content {
            v.push(StreamSegment::Content(content));
        }
        futures::stream::iter(v)
    });
    Output::from_stream(stream)
}