// llm_chain_openai/chatgpt/prompt.rs

1use async_openai::types::{
2    ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestFunctionMessageArgs,
3    ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs,
4    ChatCompletionRequestToolMessageArgs, ChatCompletionRequestUserMessageArgs,
5    ChatCompletionResponseStream, CreateChatCompletionRequest, CreateChatCompletionRequestArgs,
6    CreateChatCompletionResponse, Role,
7};
8use futures::StreamExt;
9use llm_chain::prompt::{self, Prompt};
10use llm_chain::{
11    output::{Output, StreamSegment},
12    prompt::{ChatMessage, ChatMessageCollection},
13};
14
15use super::error::OpenAIInnerError;
16
17fn convert_role(role: &prompt::ChatRole) -> Role {
18    match role {
19        prompt::ChatRole::User => Role::User,
20        prompt::ChatRole::Assistant => Role::Assistant,
21        prompt::ChatRole::System => Role::System,
22        prompt::ChatRole::Other(_s) => Role::User, // other roles are not supported by OpenAI
23    }
24}
25
26fn convert_openai_role(role: &Role) -> prompt::ChatRole {
27    match role {
28        Role::User => prompt::ChatRole::User,
29        Role::Assistant => prompt::ChatRole::Assistant,
30        Role::System => prompt::ChatRole::System,
31        Role::Tool => prompt::ChatRole::Other("Tool".to_string()),
32        Role::Function => prompt::ChatRole::Other("Function".to_string()),
33    }
34}
35
36fn format_chat_message(
37    message: &prompt::ChatMessage<String>,
38) -> Result<ChatCompletionRequestMessage, OpenAIInnerError> {
39    let role = convert_role(message.role());
40    let content = message.body().to_string();
41    let msg = match role {
42        Role::Assistant => ChatCompletionRequestMessage::Assistant(
43            ChatCompletionRequestAssistantMessageArgs::default()
44                .content(content)
45                .build()?,
46        ),
47        Role::System => ChatCompletionRequestMessage::System(
48            ChatCompletionRequestSystemMessageArgs::default()
49                .content(content)
50                .build()?,
51        ),
52        Role::User => ChatCompletionRequestMessage::User(
53            ChatCompletionRequestUserMessageArgs::default()
54                .content(content)
55                .build()?,
56        ),
57        Role::Tool => ChatCompletionRequestMessage::Tool(
58            ChatCompletionRequestToolMessageArgs::default()
59                .content(content)
60                .build()?,
61        ),
62        Role::Function => ChatCompletionRequestMessage::Function(
63            ChatCompletionRequestFunctionMessageArgs::default()
64                .content(content)
65                .build()?,
66        ),
67    };
68    Ok(msg)
69}
70
71pub fn format_chat_messages(
72    messages: prompt::ChatMessageCollection<String>,
73) -> Result<Vec<async_openai::types::ChatCompletionRequestMessage>, OpenAIInnerError> {
74    messages.iter().map(format_chat_message).collect()
75}
76
77pub fn create_chat_completion_request(
78    model: String,
79    prompt: &Prompt,
80    is_streaming: bool,
81) -> Result<CreateChatCompletionRequest, OpenAIInnerError> {
82    let messages = format_chat_messages(prompt.to_chat())?;
83    Ok(CreateChatCompletionRequestArgs::default()
84        .model(model)
85        .stream(is_streaming)
86        .messages(messages)
87        .build()?)
88}
89
90pub fn completion_to_output(resp: CreateChatCompletionResponse) -> Output {
91    let msg = resp.choices.first().unwrap().message.clone();
92    let mut col = ChatMessageCollection::new();
93    col.add_message(ChatMessage::new(
94        convert_openai_role(&msg.role),
95        msg.content.unwrap_or_default(), // "" for missing
96    ));
97    Output::new_immediate(col.into())
98}
99
100pub fn stream_to_output(resp: ChatCompletionResponseStream) -> Output {
101    let stream = resp.flat_map(|x| {
102        // Can't unwrap here!!
103        let resp = x.unwrap();
104
105        let delta = resp.choices.first().unwrap().delta.clone();
106
107        let mut v = vec![];
108
109        if let Some(role) = delta.role {
110            v.push(StreamSegment::Role(convert_openai_role(&role)));
111        }
112        if let Some(content) = delta.content {
113            v.push(StreamSegment::Content(content))
114        }
115        futures::stream::iter(v)
116    });
117    Output::from_stream(stream)
118}