1use futures::Stream;
2use hyperinfer_core::types::{ChatMessage, Choice, MessageRole, Usage};
3use hyperinfer_core::{ChatChunk, ChatRequest, ChatResponse, HyperInferError};
4use reqwest::Client;
5use serde::{Deserialize, Serialize};
6use std::pin::Pin;
7
/// Moves every complete line from the front of `raw_buf` into `lines`.
///
/// A line ends at `\n`; a single trailing `\r` is stripped so CRLF input
/// yields the same text as LF input. Each line is decoded with lossy UTF-8
/// conversion only once the whole line is available, so a multi-byte character
/// that was split across earlier appends to `raw_buf` is decoded intact.
/// Bytes after the last `\n` (an incomplete line) are left in `raw_buf` for
/// the next call. Empty lines are pushed as empty strings.
pub(crate) fn drain_lines(raw_buf: &mut Vec<u8>, lines: &mut Vec<String>) {
    // Number of leading bytes already turned into lines.
    let mut consumed = 0;
    while let Some(offset) = raw_buf[consumed..].iter().position(|&b| b == b'\n') {
        let end = consumed + offset;
        let line_bytes = &raw_buf[consumed..end];
        let line_bytes = line_bytes.strip_suffix(b"\r").unwrap_or(line_bytes);
        lines.push(String::from_utf8_lossy(line_bytes).into_owned());
        // Skip past the newline itself.
        consumed = end + 1;
    }
    // Remove the consumed prefix in one go so the unread tail is shifted once
    // per call rather than once per line.
    raw_buf.drain(..consumed);
}
27
/// HTTP caller for the OpenAI and Anthropic chat endpoints used in this file.
pub struct HttpCaller {
    /// Underlying `reqwest` client; the streaming methods clone it into the
    /// streams they return.
    client: Client,
}
31
/// Body of a non-streaming OpenAI chat completion response, as deserialized by
/// `HttpCaller::call_openai`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAiResponse {
    /// Completion identifier; copied into the resulting `ChatResponse`.
    pub id: String,
    /// Generated choices, converted one-to-one into `Choice` values.
    pub choices: Vec<OpenAiChoice>,
    /// Token accounting reported alongside the completion.
    pub usage: UsageDetail,
}
38
/// One entry of the `choices` array in an OpenAI chat completion response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAiChoice {
    /// Position of this choice within the response.
    pub index: u32,
    /// The generated message.
    pub message: Message,
    /// Why generation ended; `None` when the field is `null` or absent.
    pub finish_reason: Option<String>,
}
45
/// Message payload inside an [`OpenAiChoice`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Role as a raw string; `HttpCaller::call_openai` maps "assistant",
    /// "user" and "system" to `MessageRole` and treats anything else as
    /// assistant.
    pub role: String,
    /// Message text.
    ///
    /// NOTE(review): this is a required, non-null string, so a response whose
    /// message carries `"content": null` would fail to deserialize — confirm
    /// whether that can occur for the requests this crate sends.
    pub content: String,
}
51
/// Token usage block of an OpenAI chat completion response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsageDetail {
    /// Tokens in the prompt; becomes `Usage::input_tokens` in `call_openai`.
    pub prompt_tokens: u32,
    /// Tokens in the completion; becomes `Usage::output_tokens` in `call_openai`.
    pub completion_tokens: u32,
    /// Total reported by the API; deserialized but not copied into `Usage`
    /// by `call_openai`.
    pub total_tokens: u32,
}
58
59impl HttpCaller {
60 pub fn new() -> Result<Self, reqwest::Error> {
61 let client = Client::builder()
62 .timeout(std::time::Duration::from_secs(60))
63 .build()?;
64 Ok(Self { client })
65 }
66
67 pub async fn call_openai(
68 &self,
69 model: &str,
70 api_key: &str,
71 request: &ChatRequest,
72 ) -> Result<ChatResponse, HyperInferError> {
73 let url = "https://api.openai.com/v1/chat/completions".to_string();
74
75 let mut body = serde_json::json!({
76 "model": model,
77 "messages": request.messages,
78 "temperature": request.temperature,
79 "max_tokens": request.max_tokens,
80 });
81 if let Some(stop) = &request.stop {
82 body["stop"] = serde_json::json!(stop);
83 }
84
85 let response = self
86 .client
87 .post(&url)
88 .header("Authorization", format!("Bearer {}", api_key))
89 .header("Content-Type", "application/json")
90 .json(&body)
91 .send()
92 .await?;
93
94 if !response.status().is_success() {
95 let status = response.status();
96 let error_text = response.text().await.unwrap_or_default();
97 return Err(HyperInferError::ApiError {
98 status: status.as_u16(),
99 message: error_text,
100 });
101 }
102
103 let data: OpenAiResponse = response.json().await?;
104
105 Ok(ChatResponse {
106 id: data.id,
107 model: model.to_string(),
108 choices: data
109 .choices
110 .into_iter()
111 .map(|c| Choice {
112 index: c.index,
113 message: ChatMessage {
114 role: match c.message.role.as_str() {
115 "assistant" => MessageRole::Assistant,
116 "user" => MessageRole::User,
117 "system" => MessageRole::System,
118 other => {
119 tracing::warn!(
120 "Unknown OpenAI role '{}', defaulting to Assistant",
121 other
122 );
123 MessageRole::Assistant
124 }
125 },
126 content: c.message.content,
127 },
128 finish_reason: c.finish_reason,
129 })
130 .collect(),
131 usage: Usage {
132 input_tokens: data.usage.prompt_tokens,
133 output_tokens: data.usage.completion_tokens,
134 },
135 })
136 }
137
138 pub async fn call_anthropic(
139 &self,
140 model: &str,
141 api_key: &str,
142 request: &ChatRequest,
143 ) -> Result<ChatResponse, HyperInferError> {
144 let url = "https://api.anthropic.com/v1/messages";
145
146 let system_messages: Vec<_> = request
147 .messages
148 .iter()
149 .filter(|m| m.role == MessageRole::System)
150 .map(|m| m.content.as_str())
151 .collect();
152
153 let system = if system_messages.is_empty() {
154 None
155 } else {
156 Some(system_messages.join("\n"))
157 };
158
159 let messages: Vec<_> = request
160 .messages
161 .iter()
162 .filter(|m| m.role != MessageRole::System)
163 .map(|m| {
164 serde_json::json!({
165 "role": match m.role {
166 MessageRole::User => "user",
167 MessageRole::Assistant => "assistant",
168 _ => "user",
169 },
170 "content": m.content
171 })
172 })
173 .collect();
174
175 let mut body = serde_json::json!({
176 "model": model,
177 "messages": messages,
178 "max_tokens": request.max_tokens.unwrap_or(1024),
179 });
180
181 if let Some(s) = system {
182 body["system"] = serde_json::json!(s);
183 }
184 if let Some(t) = request.temperature {
185 body["temperature"] = serde_json::json!(t);
186 }
187 if let Some(stop) = &request.stop {
188 body["stop_sequences"] = serde_json::json!(stop);
189 }
190
191 let response = self
192 .client
193 .post(url)
194 .header("x-api-key", api_key)
195 .header("anthropic-version", "2023-06-01")
196 .header("Content-Type", "application/json")
197 .json(&body)
198 .send()
199 .await?;
200
201 if !response.status().is_success() {
202 let status = response.status();
203 let error_text = response.text().await.unwrap_or_default();
204 return Err(HyperInferError::ApiError {
205 status: status.as_u16(),
206 message: error_text,
207 });
208 }
209
210 #[derive(Deserialize)]
211 struct AnthropicResponse {
212 id: String,
213 content: Vec<ContentBlock>,
214 usage: AnthropicUsageDetail,
215 }
216
217 #[derive(Deserialize)]
218 struct ContentBlock {
219 text: Option<String>,
220 }
221
222 #[derive(Deserialize)]
223 struct AnthropicUsageDetail {
224 input_tokens: u32,
225 output_tokens: u32,
226 }
227
228 let data: AnthropicResponse = response.json().await?;
229
230 let content = data
231 .content
232 .into_iter()
233 .filter_map(|b| b.text)
234 .collect::<Vec<_>>()
235 .join("\n");
236
237 Ok(ChatResponse {
238 id: data.id,
239 model: model.to_string(),
240 choices: vec![Choice {
241 index: 0,
242 message: ChatMessage {
243 role: MessageRole::Assistant,
244 content,
245 },
246 finish_reason: Some("stop".to_string()),
247 }],
248 usage: Usage {
249 input_tokens: data.usage.input_tokens,
250 output_tokens: data.usage.output_tokens,
251 },
252 })
253 }
254
255 pub fn stream_openai(
260 &self,
261 model: &str,
262 api_key: &str,
263 request: &ChatRequest,
264 ) -> Pin<Box<dyn Stream<Item = Result<ChatChunk, HyperInferError>> + Send + 'static>> {
265 use futures::StreamExt;
266
267 let url = "https://api.openai.com/v1/chat/completions".to_string();
268 let model = model.to_string();
269 let api_key = api_key.to_string();
270
271 let mut body = serde_json::json!({
272 "model": model,
273 "messages": request.messages,
274 "temperature": request.temperature,
275 "max_tokens": request.max_tokens,
276 "stream": true,
277 "stream_options": { "include_usage": true },
278 });
279 if let Some(ref stop) = request.stop {
280 body["stop"] = serde_json::json!(stop);
281 }
282
283 let client = self.client.clone();
284
285 let stream = async_stream::try_stream! {
286 let response = client
287 .post(&url)
288 .header("Authorization", format!("Bearer {}", api_key))
289 .header("Content-Type", "application/json")
290 .json(&body)
291 .send()
292 .await?;
293
294 if !response.status().is_success() {
295 let status = response.status();
296 let error_text = response.text().await.unwrap_or_default();
297 Err(HyperInferError::ApiError {
298 status: status.as_u16(),
299 message: error_text,
300 })?;
301 return;
302 }
303
304 let mut byte_stream = response.bytes_stream();
305
306 let mut raw_buf: Vec<u8> = Vec::new();
310
311 while let Some(bytes) = byte_stream.next().await {
312 let bytes = bytes?;
313 raw_buf.extend_from_slice(&bytes);
314
315 let mut lines = Vec::new();
316 drain_lines(&mut raw_buf, &mut lines);
317
318 for line in lines {
319 if line.is_empty() || line.starts_with(':') {
320 continue;
321 }
322 let data = if let Some(d) = line.strip_prefix("data: ") { d.to_owned() } else { continue };
323 if data == "[DONE]" {
324 return;
325 }
326
327 #[derive(Deserialize)]
328 struct StreamChoice {
329 delta: DeltaContent,
330 finish_reason: Option<String>,
331 }
332 #[derive(Deserialize)]
333 struct DeltaContent {
334 #[serde(default)]
335 content: String,
336 }
337 #[derive(Deserialize)]
338 struct StreamEvent {
339 #[serde(default)]
340 id: String,
341 #[serde(default)]
342 model: String,
343 #[serde(default)]
344 choices: Vec<StreamChoice>,
345 usage: Option<OpenAiStreamUsage>,
346 }
347 #[derive(Deserialize)]
348 struct OpenAiStreamUsage {
349 prompt_tokens: u32,
350 completion_tokens: u32,
351 }
352 #[derive(Deserialize)]
355 struct OpenAiStreamError {
356 error: OpenAiErrorDetail,
357 }
358 #[derive(Deserialize)]
359 struct OpenAiErrorDetail {
360 message: String,
361 }
362
363 if let Ok(err_event) = serde_json::from_str::<OpenAiStreamError>(&data) {
366 Err(HyperInferError::StreamParse {
367 message: err_event.error.message,
368 raw: data.clone(),
369 })?;
370 return;
371 }
372
373 match serde_json::from_str::<StreamEvent>(&data) {
374 Ok(event) => {
375 let finish_reason = event.choices.first()
376 .and_then(|c| c.finish_reason.clone());
377 let delta = event.choices.first()
378 .map(|c| c.delta.content.clone())
379 .unwrap_or_default();
380 let usage = event.usage.map(|u| Usage {
381 input_tokens: u.prompt_tokens,
382 output_tokens: u.completion_tokens,
383 });
384
385 yield ChatChunk {
386 id: event.id,
387 model: event.model,
388 delta,
389 finish_reason,
390 usage,
391 };
392 }
393 Err(parse_err) => {
394 Err(HyperInferError::StreamParse {
395 message: parse_err.to_string(),
396 raw: data.clone(),
397 })?;
398 return;
399 }
400 }
401 }
402 }
403 };
404
405 Box::pin(stream)
406 }
407
408 pub fn stream_anthropic(
413 &self,
414 model: &str,
415 api_key: &str,
416 request: &ChatRequest,
417 ) -> Pin<Box<dyn Stream<Item = Result<ChatChunk, HyperInferError>> + Send + 'static>> {
418 use futures::StreamExt;
419
420 let url = "https://api.anthropic.com/v1/messages";
421 let model = model.to_string();
422 let api_key = api_key.to_string();
423
424 let system_messages: Vec<_> = request
425 .messages
426 .iter()
427 .filter(|m| m.role == MessageRole::System)
428 .map(|m| m.content.as_str())
429 .collect();
430
431 let system = if system_messages.is_empty() {
432 None
433 } else {
434 Some(system_messages.join("\n"))
435 };
436
437 let messages: Vec<_> = request
438 .messages
439 .iter()
440 .filter(|m| m.role != MessageRole::System)
441 .map(|m| {
442 serde_json::json!({
443 "role": match m.role {
444 MessageRole::User => "user",
445 MessageRole::Assistant => "assistant",
446 _ => "user",
447 },
448 "content": m.content,
449 })
450 })
451 .collect();
452
453 let mut body = serde_json::json!({
454 "model": model,
455 "messages": messages,
456 "max_tokens": request.max_tokens.unwrap_or(1024),
457 "stream": true,
458 });
459 if let Some(s) = system {
460 body["system"] = serde_json::json!(s);
461 }
462 if let Some(t) = request.temperature {
463 body["temperature"] = serde_json::json!(t);
464 }
465 if let Some(ref stop) = request.stop {
466 body["stop_sequences"] = serde_json::json!(stop);
467 }
468
469 let client = self.client.clone();
470
471 let stream = async_stream::try_stream! {
472 let response = client
473 .post(url)
474 .header("x-api-key", &api_key)
475 .header("anthropic-version", "2023-06-01")
476 .header("Content-Type", "application/json")
477 .json(&body)
478 .send()
479 .await?;
480
481 if !response.status().is_success() {
482 let status = response.status();
483 let error_text = response.text().await.unwrap_or_default();
484 Err(HyperInferError::ApiError {
485 status: status.as_u16(),
486 message: error_text,
487 })?;
488 return;
489 }
490
491 let mut byte_stream = response.bytes_stream();
492 let mut raw_buf: Vec<u8> = Vec::new();
496 let mut stream_id = String::new();
498 let mut cached_input_tokens: u32 = 0;
501
502 while let Some(bytes) = byte_stream.next().await {
503 let bytes = bytes?;
504 raw_buf.extend_from_slice(&bytes);
505
506 let mut lines = Vec::new();
507 drain_lines(&mut raw_buf, &mut lines);
508
509 for line in lines {
510 if line.is_empty() || line.starts_with(':') {
511 continue;
512 }
513 let data = if let Some(d) = line.strip_prefix("data: ") { d.to_owned() } else { continue };
514
515 #[derive(Deserialize)]
516 struct AnthropicEvent {
517 #[serde(rename = "type")]
518 event_type: String,
519 message: Option<AnthropicMessage>,
521 delta: Option<AnthropicDelta>,
523 usage: Option<AnthropicStreamUsage>,
525 }
526 #[derive(Deserialize)]
527 struct AnthropicMessage {
528 id: String,
529 usage: Option<AnthropicStreamUsage>,
530 }
531 #[derive(Deserialize)]
532 struct AnthropicDelta {
533 #[serde(rename = "type")]
534 delta_type: String,
535 #[serde(default)]
536 text: String,
537 stop_reason: Option<String>,
538 }
539 #[derive(Deserialize)]
540 struct AnthropicStreamUsage {
541 input_tokens: Option<u32>,
542 output_tokens: Option<u32>,
543 }
544
545 #[derive(Deserialize)]
548 struct AnthropicStreamError {
549 error: AnthropicErrorDetail,
550 }
551 #[derive(Deserialize)]
552 struct AnthropicErrorDetail {
553 message: String,
554 }
555
556 match serde_json::from_str::<AnthropicEvent>(&data) {
557 Ok(event) => match event.event_type.as_str() {
558 "error" => {
559 let msg = serde_json::from_str::<AnthropicStreamError>(&data)
562 .map(|e| e.error.message)
563 .unwrap_or_else(|_| data.clone());
564 Err(HyperInferError::StreamParse {
565 message: msg,
566 raw: data.clone(),
567 })?;
568 return;
569 }
570 "message_start" => {
571 if let Some(msg) = event.message {
572 stream_id = msg.id;
573 if let Some(u) = msg.usage {
574 cached_input_tokens = u.input_tokens.unwrap_or(0);
575 }
576 }
577 }
578 "content_block_delta" => {
579 if let Some(delta) = event.delta {
580 if delta.delta_type == "text_delta" {
581 yield ChatChunk {
582 id: stream_id.clone(),
583 model: model.clone(),
584 delta: delta.text,
585 finish_reason: None,
586 usage: None,
587 };
588 }
589 }
590 }
591 "message_delta" => {
592 let finish_reason = event.delta
595 .as_ref()
596 .and_then(|d| d.stop_reason.clone());
597 let usage = event.usage.map(|u| Usage {
598 input_tokens: cached_input_tokens,
599 output_tokens: u.output_tokens.unwrap_or(0),
600 });
601 yield ChatChunk {
602 id: stream_id.clone(),
603 model: model.clone(),
604 delta: String::new(),
605 finish_reason,
606 usage,
607 };
608 }
609 "message_stop" => return,
610 _ => {}
611 },
612 Err(parse_err) => {
613 Err(HyperInferError::StreamParse {
614 message: parse_err.to_string(),
615 raw: data.clone(),
616 })?;
617 return;
618 }
619 }
620 }
621 }
622 };
623
624 Box::pin(stream)
625 }
626}
627
628#[cfg(test)]
629mod tests {
630 use super::*;
631
632 #[test]
633 fn test_http_caller_new() {
634 let result = HttpCaller::new();
635 assert!(result.is_ok());
636 }
637
638 #[test]
639 fn test_openai_response_deserialization() {
640 let json = r#"{
641 "id": "chatcmpl-123",
642 "choices": [{
643 "index": 0,
644 "message": {
645 "role": "assistant",
646 "content": "Hello!"
647 },
648 "finish_reason": "stop"
649 }],
650 "usage": {
651 "prompt_tokens": 10,
652 "completion_tokens": 5,
653 "total_tokens": 15
654 }
655 }"#;
656
657 let response: OpenAiResponse = serde_json::from_str(json).unwrap();
658 assert_eq!(response.id, "chatcmpl-123");
659 assert_eq!(response.choices.len(), 1);
660 assert_eq!(response.choices[0].message.content, "Hello!");
661 assert_eq!(response.usage.total_tokens, 15);
662 }
663
664 #[test]
665 fn test_openai_choice_deserialization() {
666 let json = r#"{
667 "index": 0,
668 "message": {
669 "role": "user",
670 "content": "Test message"
671 },
672 "finish_reason": "length"
673 }"#;
674
675 let choice: OpenAiChoice = serde_json::from_str(json).unwrap();
676 assert_eq!(choice.index, 0);
677 assert_eq!(choice.message.role, "user");
678 assert_eq!(choice.message.content, "Test message");
679 assert_eq!(choice.finish_reason, Some("length".to_string()));
680 }
681
682 #[test]
683 fn test_usage_deserialization() {
684 let json = r#"{
685 "prompt_tokens": 100,
686 "completion_tokens": 50,
687 "total_tokens": 150
688 }"#;
689
690 let usage: UsageDetail = serde_json::from_str(json).unwrap();
691 assert_eq!(usage.prompt_tokens, 100);
692 assert_eq!(usage.completion_tokens, 50);
693 assert_eq!(usage.total_tokens, 150);
694 }
695
696 #[test]
697 fn test_message_serialization() {
698 let message = Message {
699 role: "assistant".to_string(),
700 content: "Response text".to_string(),
701 };
702
703 let json = serde_json::to_string(&message).unwrap();
704 assert!(json.contains("assistant"));
705 assert!(json.contains("Response text"));
706 }
707
708 #[test]
709 fn test_openai_response_clone() {
710 let response = OpenAiResponse {
711 id: "test-id".to_string(),
712 choices: vec![],
713 usage: UsageDetail {
714 prompt_tokens: 10,
715 completion_tokens: 5,
716 total_tokens: 15,
717 },
718 };
719
720 let cloned = response.clone();
721 assert_eq!(response.id, cloned.id);
722 assert_eq!(response.usage.total_tokens, cloned.usage.total_tokens);
723 }
724
725 #[test]
726 fn test_openai_choice_with_no_finish_reason() {
727 let json = r#"{
728 "index": 1,
729 "message": {
730 "role": "assistant",
731 "content": "Partial response"
732 },
733 "finish_reason": null
734 }"#;
735
736 let choice: OpenAiChoice = serde_json::from_str(json).unwrap();
737 assert_eq!(choice.index, 1);
738 assert_eq!(choice.finish_reason, None);
739 }
740
741 #[tokio::test]
742 async fn test_call_openai_request_structure() {
743 let request = ChatRequest {
744 model: "gpt-4".to_string(),
745 messages: vec![ChatMessage {
746 role: MessageRole::User,
747 content: "Hello".to_string(),
748 }],
749 temperature: Some(0.7),
750 max_tokens: Some(100),
751 stream: None,
752 stop: None,
753 };
754
755 let body = serde_json::json!({
758 "model": "gpt-4",
759 "messages": request.messages,
760 "temperature": request.temperature,
761 "max_tokens": request.max_tokens,
762 });
763
764 assert_eq!(body["model"], "gpt-4");
765 assert_eq!(body["temperature"], 0.7);
766 assert_eq!(body["max_tokens"], 100);
767 }
768
769 #[tokio::test]
770 async fn test_call_anthropic_request_structure() {
771 let request = ChatRequest {
772 model: "claude-3".to_string(),
773 messages: vec![
774 ChatMessage {
775 role: MessageRole::System,
776 content: "You are helpful".to_string(),
777 },
778 ChatMessage {
779 role: MessageRole::User,
780 content: "Hello".to_string(),
781 },
782 ],
783 temperature: Some(0.5),
784 max_tokens: Some(200),
785 stream: None,
786 stop: None,
787 };
788
789 let system = request
791 .messages
792 .iter()
793 .find(|m| m.role == MessageRole::System)
794 .map(|m| m.content.clone());
795
796 assert_eq!(system, Some("You are helpful".to_string()));
797
798 let messages: Vec<_> = request
800 .messages
801 .iter()
802 .filter(|m| m.role != MessageRole::System)
803 .collect();
804
805 assert_eq!(messages.len(), 1);
806 assert_eq!(messages[0].content, "Hello");
807 }
808
809 fn feed_chunks(chunks: &[&[u8]]) -> (Vec<String>, Vec<u8>) {
814 let mut raw_buf: Vec<u8> = Vec::new();
815 let mut all_lines: Vec<String> = Vec::new();
816 for chunk in chunks {
817 raw_buf.extend_from_slice(chunk);
818 drain_lines(&mut raw_buf, &mut all_lines);
819 }
820 (all_lines, raw_buf)
821 }
822
823 #[test]
824 fn test_drain_lines_single_chunk() {
825 let (lines, remainder) = feed_chunks(&[b"data: hello\ndata: world\n"]);
826 assert_eq!(lines, vec!["data: hello", "data: world"]);
827 assert!(remainder.is_empty());
828 }
829
830 #[test]
831 fn test_drain_lines_crlf_endings() {
832 let (lines, remainder) = feed_chunks(&[b"data: hello\r\ndata: world\r\n"]);
833 assert_eq!(lines, vec!["data: hello", "data: world"]);
834 assert!(remainder.is_empty());
835 }
836
837 #[test]
838 fn test_drain_lines_incomplete_line_buffered() {
839 let (lines, remainder) = feed_chunks(&[b"data: hello\n", b"data: partial"]);
841 assert_eq!(lines, vec!["data: hello"]);
842 assert_eq!(remainder, b"data: partial");
843 }
844
845 #[test]
846 fn test_drain_lines_multibyte_split_across_chunks() {
847 let chunk1: &[u8] = b"data: caf\xc3"; let chunk2: &[u8] = b"\xa9\ndata: done\n"; let (lines, remainder) = feed_chunks(&[chunk1, chunk2]);
856 assert_eq!(lines[0], "data: café");
857 assert_eq!(lines[1], "data: done");
858 assert!(remainder.is_empty());
859 }
860
861 #[test]
862 fn test_drain_lines_three_byte_split_across_three_chunks() {
863 let chunk1: &[u8] = b"data: \xe4";
866 let chunk2: &[u8] = b"\xb8";
867 let chunk3: &[u8] = b"\xad\n";
868 let (lines, remainder) = feed_chunks(&[chunk1, chunk2, chunk3]);
869 assert_eq!(lines, vec!["data: 中"]);
870 assert!(remainder.is_empty());
871 }
872
873 #[test]
874 fn test_drain_lines_empty_lines_preserved() {
875 let (lines, _) = feed_chunks(&[b"data: hello\n\ndata: world\n"]);
878 assert_eq!(lines, vec!["data: hello", "", "data: world"]);
879 }
880
881 #[test]
882 fn test_drain_lines_no_newline_nothing_emitted() {
883 let (lines, remainder) = feed_chunks(&[b"data: no newline yet"]);
884 assert!(lines.is_empty());
885 assert_eq!(remainder, b"data: no newline yet");
886 }
887}