1mod client;
2pub mod config;
3mod conversions;
4
5pub use config::OpenRouterConfig;
6#[cfg(feature = "golem")]
7pub use config::OpenRouterHostConfig;
8
9use crate::client::{ChatCompletionChunk, CompletionsApi, CompletionsRequest, FunctionCall};
10use crate::conversions::{
11 convert_finish_reason, convert_usage, events_to_request, process_response,
12};
13use golem_ai_llm::chat_stream::{LlmChatStream, LlmChatStreamState};
14use golem_ai_llm::durability::{DurableLLM, ExtendedLlmProvider};
15use golem_ai_llm::error::error_code_from_status;
16use golem_ai_llm::event_source::EventSource;
17use golem_ai_llm::model::{
18 ChatStream, Config, ContentPart, Error, ErrorCode, Event, FinishReason, Message, Response,
19 ResponseMetadata, Role, StreamDelta, StreamEvent, ToolCall,
20};
21use golem_ai_llm::wasi_compat::Pollable;
22use golem_ai_llm::LlmProvider;
23use golem_wasi_http::StatusCode;
24use log::trace;
25use std::cell::{Ref, RefCell, RefMut};
26use std::collections::{HashMap, HashSet};
27
/// Accumulator for a streamed tool call whose JSON arguments arrive split across chunks.
///
/// Fragments are stored per stream `index`: the chunk that opens a tool call supplies `id`
/// and `name`, and every later chunk with the same index appends its `arguments` to `json`.
#[derive(Default)]
struct JsonFragment {
    /// Identifier of the tool call (left empty if no chunk ever supplied one).
    id: String,
    /// Name of the function being invoked (left empty if no chunk ever supplied one).
    name: String,
    /// Concatenation of all argument pieces received so far.
    json: String,
}
34
35pub struct OpenRouterChatStream {
36 stream: RefCell<Option<EventSource>>,
37 failure: Option<Error>,
38 finished: RefCell<bool>,
39 finish_reason: RefCell<Option<FinishReason>>,
40 json_fragments: RefCell<HashMap<u32, JsonFragment>>,
41}
42
43impl OpenRouterChatStream {
44 pub fn new(stream: EventSource) -> LlmChatStream<Self> {
45 LlmChatStream::new(OpenRouterChatStream {
46 stream: RefCell::new(Some(stream)),
47 failure: None,
48 finished: RefCell::new(false),
49 finish_reason: RefCell::new(None),
50 json_fragments: RefCell::new(HashMap::new()),
51 })
52 }
53
54 pub fn failed(error: Error) -> LlmChatStream<Self> {
55 LlmChatStream::new(OpenRouterChatStream {
56 stream: RefCell::new(None),
57 failure: Some(error),
58 finished: RefCell::new(false),
59 finish_reason: RefCell::new(None),
60 json_fragments: RefCell::new(HashMap::new()),
61 })
62 }
63}
64
65impl LlmChatStreamState for OpenRouterChatStream {
66 fn failure(&self) -> &Option<Error> {
67 &self.failure
68 }
69
70 fn is_finished(&self) -> bool {
71 *self.finished.borrow()
72 }
73
74 fn set_finished(&self) {
75 *self.finished.borrow_mut() = true;
76 }
77
78 fn stream(&self) -> Ref<'_, Option<EventSource>> {
79 self.stream.borrow()
80 }
81
82 fn stream_mut(&self) -> RefMut<'_, Option<EventSource>> {
83 self.stream.borrow_mut()
84 }
85
86 fn decode_message(&self, raw: &str) -> Result<Option<StreamEvent>, Error> {
87 fn decode_internal_error<S: Into<String>>(message: S) -> Error {
88 Error {
89 code: ErrorCode::InternalError,
90 message: message.into(),
91 provider_error_json: None,
92 }
93 }
94
95 trace!("Received raw stream event: {raw}");
96 if raw.starts_with(": ") {
97 Ok(None) } else {
99 let json: serde_json::Value = serde_json::from_str(raw).map_err(|err| {
100 decode_internal_error(format!("Failed to deserialize stream event: {err}"))
101 })?;
102
103 let typ = json
104 .as_object()
105 .and_then(|obj| obj.get("object"))
106 .and_then(|v| v.as_str());
107 match typ {
108 Some("chat.completion.chunk") => {
109 let message: ChatCompletionChunk =
110 serde_json::from_value(json).map_err(|err| {
111 decode_internal_error(format!("Failed to parse stream event: {err}"))
112 })?;
113 if let Some(usage) = message.usage {
114 let finish_reason = self.finish_reason.borrow();
115 Ok(Some(StreamEvent::Finish(ResponseMetadata {
116 finish_reason: *finish_reason,
117 usage: Some(convert_usage(&usage)),
118 provider_id: None,
119 timestamp: Some(message.created.to_string()),
120 provider_metadata_json: None,
121 })))
122 } else if let Some(choice) = message.choices.into_iter().next() {
123 if let Some(finish_reason) = choice.finish_reason {
124 *self.finish_reason.borrow_mut() =
125 Some(convert_finish_reason(&finish_reason));
126 }
127 if let Some(error) = choice.error {
128 Err(Error {
129 code: error_code_from_status(
130 TryInto::<u16>::try_into(error.code)
131 .ok()
132 .and_then(|code| StatusCode::from_u16(code).ok())
133 .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
134 ),
135 message: error.message,
136 provider_error_json: error
137 .metadata
138 .map(|value| serde_json::to_string(&value).unwrap()),
139 })
140 } else {
141 let content = choice
142 .delta
143 .content
144 .map(|text| vec![ContentPart::Text(text)]);
145
146 let mut seen_indices = HashSet::new();
147 let mut tool_calls = Vec::new();
148 let mut json_fragments = self.json_fragments.borrow_mut();
149
150 for tool_call in choice.delta.tool_calls.unwrap_or_default() {
151 match tool_call {
152 client::ToolCall::Function {
153 id: Some(id),
154 function:
155 FunctionCall {
156 name: Some(name),
157 arguments,
158 },
159 index: None,
160 } => {
161 tool_calls.push(ToolCall {
163 id,
164 name,
165 arguments_json: arguments,
166 });
167 }
168 client::ToolCall::Function {
169 id: Some(id),
170 function:
171 FunctionCall {
172 name: Some(name),
173 arguments,
174 },
175 index: Some(index),
176 } => {
177 json_fragments.insert(
179 index,
180 JsonFragment {
181 id,
182 name,
183 json: arguments,
184 },
185 );
186 seen_indices.insert(index);
187 }
188 client::ToolCall::Function {
189 id: _,
190 function: FunctionCall { name: _, arguments },
191 index: Some(index),
192 } => {
193 let fragment = json_fragments.entry(index).or_default();
195 fragment.json.push_str(&arguments);
196 seen_indices.insert(index);
197 }
198 _ => {
199 return Err(decode_internal_error(format!(
200 "Unexpected tool call format: {tool_call:?}"
201 )));
202 }
203 }
204 }
205
206 let indices =
207 json_fragments.keys().copied().collect::<Vec<_>>().clone();
208 for index in indices {
209 if !seen_indices.contains(&index) {
210 let fragment = json_fragments.remove(&index).unwrap();
212 tool_calls.push(ToolCall {
213 id: fragment.id,
214 name: fragment.name,
215 arguments_json: fragment.json,
216 });
217 }
218 }
219
220 Ok(Some(StreamEvent::Delta(StreamDelta {
221 content,
222 tool_calls: if tool_calls.is_empty() {
223 None
224 } else {
225 Some(tool_calls)
226 },
227 })))
228 }
229 } else {
230 Ok(None)
231 }
232 }
233 Some(_) => Ok(None),
234 None => Err(decode_internal_error(
235 "Unexpected stream event format, does not have 'object' field".to_string(),
236 )),
237 }
238 }
239 }
240}
241
/// OpenRouter chat-completion provider.
///
/// A zero-sized marker type: all per-call state is passed in through the provider config,
/// the events, and the request.
pub struct OpenRouter;
243
244impl OpenRouter {
245 fn request(client: CompletionsApi, request: CompletionsRequest) -> Result<Response, Error> {
246 let response = client.send_messages(request)?;
247 process_response(response)
248 }
249
250 fn streaming_request(
251 client: CompletionsApi,
252 mut request: CompletionsRequest,
253 ) -> LlmChatStream<OpenRouterChatStream> {
254 request.stream = Some(true);
255 match client.stream_send_messages(request) {
256 Ok(stream) => OpenRouterChatStream::new(stream),
257 Err(err) => OpenRouterChatStream::failed(err),
258 }
259 }
260}
261
262impl LlmProvider for OpenRouter {
263 type ChatStream = LlmChatStream<OpenRouterChatStream>;
264 type ProviderConfig = OpenRouterConfig;
265
266 async fn send(
267 provider_config: Self::ProviderConfig,
268 events: Vec<Event>,
269 config: Config,
270 ) -> Result<Response, Error> {
271 let client = CompletionsApi::new(&provider_config);
272 let request = events_to_request(events, config)?;
273 Self::request(client, request)
274 }
275
276 async fn stream(
277 provider_config: Self::ProviderConfig,
278 events: Vec<Event>,
279 config: Config,
280 ) -> ChatStream {
281 ChatStream::new(Self::unwrapped_stream(provider_config, events, config).await)
282 }
283}
284
285impl ExtendedLlmProvider for OpenRouter {
286 async fn unwrapped_stream(
287 provider_config: Self::ProviderConfig,
288 events: Vec<Event>,
289 config: Config,
290 ) -> LlmChatStream<OpenRouterChatStream> {
291 let client = CompletionsApi::new(&provider_config);
292 match events_to_request(events, config) {
293 Ok(request) => Self::streaming_request(client, request),
294 Err(err) => OpenRouterChatStream::failed(err),
295 }
296 }
297
298 fn retry_prompt(
299 original_events: &[Result<Event, Error>],
300 partial_result: &[StreamDelta],
301 ) -> Vec<Event> {
302 let mut extended_events = Vec::new();
303 extended_events.push(Event::Message(Message {
304 role: Role::System,
305 name: None,
306 content: vec![
307 ContentPart::Text(
308 "You were asked the same question previously, but the response was interrupted before completion. \
309 Please continue your response from where you left off. \
310 Do not include the part of the response that was already seen.".to_string()),
311 ],
312 }));
313 extended_events.push(Event::Message(Message {
314 role: Role::User,
315 name: None,
316 content: vec![ContentPart::Text(
317 "Here is the original question:".to_string(),
318 )],
319 }));
320 extended_events.extend(
321 original_events
322 .iter()
323 .filter_map(|event| event.as_ref().ok().cloned()),
324 );
325
326 let mut partial_result_as_content = Vec::new();
327 for delta in partial_result {
328 if let Some(contents) = &delta.content {
329 partial_result_as_content.extend_from_slice(contents);
330 }
331 if let Some(tool_calls) = &delta.tool_calls {
332 for tool_call in tool_calls {
333 partial_result_as_content.push(ContentPart::Text(format!(
334 "<tool-call id=\"{}\" name=\"{}\" arguments=\"{}\"/>",
335 tool_call.id, tool_call.name, tool_call.arguments_json,
336 )));
337 }
338 }
339 }
340
341 extended_events.push(Event::Message(Message {
342 role: Role::User,
343 name: None,
344 content: vec![ContentPart::Text(
345 "Here is the partial response that was successfully received:".to_string(),
346 )]
347 .into_iter()
348 .chain(partial_result_as_content)
349 .collect(),
350 }));
351 extended_events
352 }
353
354 fn subscribe(stream: &Self::ChatStream) -> Pollable {
355 stream.subscribe()
356 }
357}
358
359pub type DurableOpenRouter = DurableLLM<OpenRouter>;