Skip to main content

golem_ai_llm_openrouter/
lib.rs

1mod client;
2pub mod config;
3mod conversions;
4
5pub use config::OpenRouterConfig;
6#[cfg(feature = "golem")]
7pub use config::OpenRouterHostConfig;
8
9use crate::client::{ChatCompletionChunk, CompletionsApi, CompletionsRequest, FunctionCall};
10use crate::conversions::{
11    convert_finish_reason, convert_usage, events_to_request, process_response,
12};
13use golem_ai_llm::chat_stream::{LlmChatStream, LlmChatStreamState};
14use golem_ai_llm::durability::{DurableLLM, ExtendedLlmProvider};
15use golem_ai_llm::error::error_code_from_status;
16use golem_ai_llm::event_source::EventSource;
17use golem_ai_llm::model::{
18    ChatStream, Config, ContentPart, Error, ErrorCode, Event, FinishReason, Message, Response,
19    ResponseMetadata, Role, StreamDelta, StreamEvent, ToolCall,
20};
21use golem_ai_llm::wasi_compat::Pollable;
22use golem_ai_llm::LlmProvider;
23use golem_wasi_http::StatusCode;
24use log::trace;
25use std::cell::{Ref, RefCell, RefMut};
26use std::collections::{HashMap, HashSet};
27
/// Accumulator for one tool call whose JSON arguments arrive split across stream chunks.
///
/// Instances live in `OpenRouterChatStream::json_fragments`, keyed by the tool call's `index`.
#[derive(Debug, Default)]
struct JsonFragment {
    /// Tool call id; stays empty if argument fragments arrive before the opening chunk.
    id: String,
    /// Name of the called function; stays empty until the opening chunk is seen.
    name: String,
    /// Concatenation of every `arguments` fragment received so far, in arrival order.
    json: String,
}
34
35pub struct OpenRouterChatStream {
36    stream: RefCell<Option<EventSource>>,
37    failure: Option<Error>,
38    finished: RefCell<bool>,
39    finish_reason: RefCell<Option<FinishReason>>,
40    json_fragments: RefCell<HashMap<u32, JsonFragment>>,
41}
42
43impl OpenRouterChatStream {
44    pub fn new(stream: EventSource) -> LlmChatStream<Self> {
45        LlmChatStream::new(OpenRouterChatStream {
46            stream: RefCell::new(Some(stream)),
47            failure: None,
48            finished: RefCell::new(false),
49            finish_reason: RefCell::new(None),
50            json_fragments: RefCell::new(HashMap::new()),
51        })
52    }
53
54    pub fn failed(error: Error) -> LlmChatStream<Self> {
55        LlmChatStream::new(OpenRouterChatStream {
56            stream: RefCell::new(None),
57            failure: Some(error),
58            finished: RefCell::new(false),
59            finish_reason: RefCell::new(None),
60            json_fragments: RefCell::new(HashMap::new()),
61        })
62    }
63}
64
65impl LlmChatStreamState for OpenRouterChatStream {
66    fn failure(&self) -> &Option<Error> {
67        &self.failure
68    }
69
70    fn is_finished(&self) -> bool {
71        *self.finished.borrow()
72    }
73
74    fn set_finished(&self) {
75        *self.finished.borrow_mut() = true;
76    }
77
78    fn stream(&self) -> Ref<'_, Option<EventSource>> {
79        self.stream.borrow()
80    }
81
82    fn stream_mut(&self) -> RefMut<'_, Option<EventSource>> {
83        self.stream.borrow_mut()
84    }
85
86    fn decode_message(&self, raw: &str) -> Result<Option<StreamEvent>, Error> {
87        fn decode_internal_error<S: Into<String>>(message: S) -> Error {
88            Error {
89                code: ErrorCode::InternalError,
90                message: message.into(),
91                provider_error_json: None,
92            }
93        }
94
95        trace!("Received raw stream event: {raw}");
96        if raw.starts_with(": ") {
97            Ok(None) // comment
98        } else {
99            let json: serde_json::Value = serde_json::from_str(raw).map_err(|err| {
100                decode_internal_error(format!("Failed to deserialize stream event: {err}"))
101            })?;
102
103            let typ = json
104                .as_object()
105                .and_then(|obj| obj.get("object"))
106                .and_then(|v| v.as_str());
107            match typ {
108                Some("chat.completion.chunk") => {
109                    let message: ChatCompletionChunk =
110                        serde_json::from_value(json).map_err(|err| {
111                            decode_internal_error(format!("Failed to parse stream event: {err}"))
112                        })?;
113                    if let Some(usage) = message.usage {
114                        let finish_reason = self.finish_reason.borrow();
115                        Ok(Some(StreamEvent::Finish(ResponseMetadata {
116                            finish_reason: *finish_reason,
117                            usage: Some(convert_usage(&usage)),
118                            provider_id: None,
119                            timestamp: Some(message.created.to_string()),
120                            provider_metadata_json: None,
121                        })))
122                    } else if let Some(choice) = message.choices.into_iter().next() {
123                        if let Some(finish_reason) = choice.finish_reason {
124                            *self.finish_reason.borrow_mut() =
125                                Some(convert_finish_reason(&finish_reason));
126                        }
127                        if let Some(error) = choice.error {
128                            Err(Error {
129                                code: error_code_from_status(
130                                    TryInto::<u16>::try_into(error.code)
131                                        .ok()
132                                        .and_then(|code| StatusCode::from_u16(code).ok())
133                                        .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
134                                ),
135                                message: error.message,
136                                provider_error_json: error
137                                    .metadata
138                                    .map(|value| serde_json::to_string(&value).unwrap()),
139                            })
140                        } else {
141                            let content = choice
142                                .delta
143                                .content
144                                .map(|text| vec![ContentPart::Text(text)]);
145
146                            let mut seen_indices = HashSet::new();
147                            let mut tool_calls = Vec::new();
148                            let mut json_fragments = self.json_fragments.borrow_mut();
149
150                            for tool_call in choice.delta.tool_calls.unwrap_or_default() {
151                                match tool_call {
152                                    client::ToolCall::Function {
153                                        id: Some(id),
154                                        function:
155                                            FunctionCall {
156                                                name: Some(name),
157                                                arguments,
158                                            },
159                                        index: None,
160                                    } => {
161                                        // Full tool call
162                                        tool_calls.push(ToolCall {
163                                            id,
164                                            name,
165                                            arguments_json: arguments,
166                                        });
167                                    }
168                                    client::ToolCall::Function {
169                                        id: Some(id),
170                                        function:
171                                            FunctionCall {
172                                                name: Some(name),
173                                                arguments,
174                                            },
175                                        index: Some(index),
176                                    } => {
177                                        // Beginning of a streamed tool call
178                                        json_fragments.insert(
179                                            index,
180                                            JsonFragment {
181                                                id,
182                                                name,
183                                                json: arguments,
184                                            },
185                                        );
186                                        seen_indices.insert(index);
187                                    }
188                                    client::ToolCall::Function {
189                                        id: _,
190                                        function: FunctionCall { name: _, arguments },
191                                        index: Some(index),
192                                    } => {
193                                        // Fragment
194                                        let fragment = json_fragments.entry(index).or_default();
195                                        fragment.json.push_str(&arguments);
196                                        seen_indices.insert(index);
197                                    }
198                                    _ => {
199                                        return Err(decode_internal_error(format!(
200                                            "Unexpected tool call format: {tool_call:?}"
201                                        )));
202                                    }
203                                }
204                            }
205
206                            let indices =
207                                json_fragments.keys().copied().collect::<Vec<_>>().clone();
208                            for index in indices {
209                                if !seen_indices.contains(&index) {
210                                    // Emitting finished tool call
211                                    let fragment = json_fragments.remove(&index).unwrap();
212                                    tool_calls.push(ToolCall {
213                                        id: fragment.id,
214                                        name: fragment.name,
215                                        arguments_json: fragment.json,
216                                    });
217                                }
218                            }
219
220                            Ok(Some(StreamEvent::Delta(StreamDelta {
221                                content,
222                                tool_calls: if tool_calls.is_empty() {
223                                    None
224                                } else {
225                                    Some(tool_calls)
226                                },
227                            })))
228                        }
229                    } else {
230                        Ok(None)
231                    }
232                }
233                Some(_) => Ok(None),
234                None => Err(decode_internal_error(
235                    "Unexpected stream event format, does not have 'object' field".to_string(),
236                )),
237            }
238        }
239    }
240}
241
242pub struct OpenRouter;
243
244impl OpenRouter {
245    fn request(client: CompletionsApi, request: CompletionsRequest) -> Result<Response, Error> {
246        let response = client.send_messages(request)?;
247        process_response(response)
248    }
249
250    fn streaming_request(
251        client: CompletionsApi,
252        mut request: CompletionsRequest,
253    ) -> LlmChatStream<OpenRouterChatStream> {
254        request.stream = Some(true);
255        match client.stream_send_messages(request) {
256            Ok(stream) => OpenRouterChatStream::new(stream),
257            Err(err) => OpenRouterChatStream::failed(err),
258        }
259    }
260}
261
262impl LlmProvider for OpenRouter {
263    type ChatStream = LlmChatStream<OpenRouterChatStream>;
264    type ProviderConfig = OpenRouterConfig;
265
266    async fn send(
267        provider_config: Self::ProviderConfig,
268        events: Vec<Event>,
269        config: Config,
270    ) -> Result<Response, Error> {
271        let client = CompletionsApi::new(&provider_config);
272        let request = events_to_request(events, config)?;
273        Self::request(client, request)
274    }
275
276    async fn stream(
277        provider_config: Self::ProviderConfig,
278        events: Vec<Event>,
279        config: Config,
280    ) -> ChatStream {
281        ChatStream::new(Self::unwrapped_stream(provider_config, events, config).await)
282    }
283}
284
285impl ExtendedLlmProvider for OpenRouter {
286    async fn unwrapped_stream(
287        provider_config: Self::ProviderConfig,
288        events: Vec<Event>,
289        config: Config,
290    ) -> LlmChatStream<OpenRouterChatStream> {
291        let client = CompletionsApi::new(&provider_config);
292        match events_to_request(events, config) {
293            Ok(request) => Self::streaming_request(client, request),
294            Err(err) => OpenRouterChatStream::failed(err),
295        }
296    }
297
298    fn retry_prompt(
299        original_events: &[Result<Event, Error>],
300        partial_result: &[StreamDelta],
301    ) -> Vec<Event> {
302        let mut extended_events = Vec::new();
303        extended_events.push(Event::Message(Message {
304            role: Role::System,
305            name: None,
306            content: vec![
307                ContentPart::Text(
308                    "You were asked the same question previously, but the response was interrupted before completion. \
309                     Please continue your response from where you left off. \
310                     Do not include the part of the response that was already seen.".to_string()),
311            ],
312        }));
313        extended_events.push(Event::Message(Message {
314            role: Role::User,
315            name: None,
316            content: vec![ContentPart::Text(
317                "Here is the original question:".to_string(),
318            )],
319        }));
320        extended_events.extend(
321            original_events
322                .iter()
323                .filter_map(|event| event.as_ref().ok().cloned()),
324        );
325
326        let mut partial_result_as_content = Vec::new();
327        for delta in partial_result {
328            if let Some(contents) = &delta.content {
329                partial_result_as_content.extend_from_slice(contents);
330            }
331            if let Some(tool_calls) = &delta.tool_calls {
332                for tool_call in tool_calls {
333                    partial_result_as_content.push(ContentPart::Text(format!(
334                        "<tool-call id=\"{}\" name=\"{}\" arguments=\"{}\"/>",
335                        tool_call.id, tool_call.name, tool_call.arguments_json,
336                    )));
337                }
338            }
339        }
340
341        extended_events.push(Event::Message(Message {
342            role: Role::User,
343            name: None,
344            content: vec![ContentPart::Text(
345                "Here is the partial response that was successfully received:".to_string(),
346            )]
347            .into_iter()
348            .chain(partial_result_as_content)
349            .collect(),
350        }));
351        extended_events
352    }
353
354    fn subscribe(stream: &Self::ChatStream) -> Pollable {
355        stream.subscribe()
356    }
357}
358
359pub type DurableOpenRouter = DurableLLM<OpenRouter>;