use crate::core::{GenericProvider, HttpClient, Protocol};
use crate::protocols::OpenAIProtocol;
use crate::error::LlmConnectorError;
use std::collections::HashMap;
/// Protocol adapter for Volcengine (Ark) chat models.
///
/// All wire-format work (request building, response parsing, error mapping
/// and auth headers) is delegated to the wrapped [`OpenAIProtocol`]; this type
/// only contributes its own name and endpoint paths.
#[derive(Debug, Clone)]
pub struct VolcengineProtocol {
    /// OpenAI-compatible protocol that performs the actual encoding/decoding.
    inner: OpenAIProtocol,
}
impl VolcengineProtocol {
    /// Creates a Volcengine protocol adapter authenticated with `api_key`.
    ///
    /// The key is handed to the inner OpenAI-compatible protocol, which is
    /// what later produces the authentication headers.
    pub fn new(api_key: &str) -> Self {
        let inner = OpenAIProtocol::new(api_key);
        Self { inner }
    }
}
#[async_trait::async_trait]
impl Protocol for VolcengineProtocol {
    // Volcengine's chat API is OpenAI-compatible, so the wire types are shared.
    type Request = <OpenAIProtocol as Protocol>::Request;
    type Response = <OpenAIProtocol as Protocol>::Response;

    /// Identifier of this provider: `"volcengine"`.
    fn name(&self) -> &str {
        "volcengine"
    }

    /// Chat-completions URL: `<base>/api/v3/chat/completions`.
    ///
    /// Trailing slashes on `base_url` are ignored. If `base_url` already ends
    /// in the `/api/v3` version prefix (the form commonly configured for
    /// OpenAI-compatible clients), the prefix is not appended a second time.
    fn chat_endpoint(&self, base_url: &str) -> String {
        let base = base_url.trim_end_matches('/');
        // Avoid `.../api/v3/api/v3/...` when the caller already included the prefix.
        let base = base.strip_suffix("/api/v3").unwrap_or(base);
        format!("{}/api/v3/chat/completions", base)
    }

    /// Model-listing URL: `<base>/api/v3/models`.
    ///
    /// `base_url` is normalized the same way as in `chat_endpoint`: trailing
    /// slashes are dropped and an existing `/api/v3` suffix is not duplicated.
    fn models_endpoint(&self, base_url: &str) -> Option<String> {
        let base = base_url.trim_end_matches('/');
        let base = base.strip_suffix("/api/v3").unwrap_or(base);
        Some(format!("{}/api/v3/models", base))
    }

    /// Builds the request body via the inner OpenAI protocol.
    fn build_request(&self, request: &crate::types::ChatRequest) -> Result<Self::Request, LlmConnectorError> {
        self.inner.build_request(request)
    }

    /// Parses a non-streaming response body via the inner OpenAI protocol.
    fn parse_response(&self, response: &str) -> Result<crate::types::ChatResponse, LlmConnectorError> {
        self.inner.parse_response(response)
    }

    /// Maps an HTTP status and message to an error via the inner OpenAI protocol.
    fn map_error(&self, status: u16, message: &str) -> LlmConnectorError {
        self.inner.map_error(status, message)
    }

    /// Authentication headers, produced by the inner OpenAI protocol.
    fn auth_headers(&self) -> Vec<(String, String)> {
        self.inner.auth_headers()
    }

    /// Parses a streaming response via the inner OpenAI protocol.
    #[cfg(feature = "streaming")]
    async fn parse_stream_response(&self, response: reqwest::Response) -> Result<crate::types::ChatStream, LlmConnectorError> {
        self.inner.parse_stream_response(response).await
    }
}
/// Volcengine provider: the generic provider driven by [`VolcengineProtocol`].
pub type VolcengineProvider = GenericProvider<VolcengineProtocol>;
/// Creates a Volcengine provider with default settings.
///
/// Equivalent to [`volcengine_with_config`] with `None` for the base URL
/// (so that function's built-in default host is used), the timeout and the proxy.
///
/// # Errors
///
/// Propagates any error returned while building the underlying HTTP client.
pub fn volcengine(api_key: &str) -> Result<VolcengineProvider, LlmConnectorError> {
    let (base_url, timeout_secs, proxy) = (None, None, None);
    volcengine_with_config(api_key, base_url, timeout_secs, proxy)
}
/// Creates a Volcengine provider with optional overrides.
///
/// * `api_key` – Volcengine (Ark) API key; the protocol turns it into auth headers.
/// * `base_url` – API host; defaults to `https://ark.cn-beijing.volces.com`.
///   The `/api/v3/...` paths are appended by [`VolcengineProtocol`].
/// * `timeout_secs` – optional timeout, forwarded as-is to [`HttpClient::with_config`].
/// * `proxy` – optional proxy, forwarded as-is to [`HttpClient::with_config`].
///
/// # Errors
///
/// Returns the error from [`HttpClient::with_config`] if the client cannot be built.
pub fn volcengine_with_config(
    api_key: &str,
    base_url: Option<&str>,
    timeout_secs: Option<u64>,
    proxy: Option<&str>,
) -> Result<VolcengineProvider, LlmConnectorError> {
    const DEFAULT_BASE_URL: &str = "https://ark.cn-beijing.volces.com";

    let protocol = VolcengineProtocol::new(api_key);
    let http = HttpClient::with_config(
        base_url.unwrap_or(DEFAULT_BASE_URL),
        timeout_secs,
        proxy,
    )?;
    // The client sends the protocol's auth headers on every request.
    let headers = protocol
        .auth_headers()
        .into_iter()
        .collect::<HashMap<String, String>>();
    Ok(GenericProvider::new(protocol, http.with_headers(headers)))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Dummy API key shared by all tests; no network calls are made.
    const KEY: &str = "test-key";

    /// The default constructor builds a provider without error.
    #[test]
    fn test_volcengine() {
        assert!(volcengine(KEY).is_ok());
    }

    /// Overriding the base URL and timeout still yields a provider.
    #[test]
    fn test_volcengine_with_config() {
        let result = volcengine_with_config(KEY, Some("https://custom.url"), Some(60), None);
        assert!(result.is_ok());
    }

    /// The chat endpoint is the base URL plus the `/api/v3/chat/completions` path.
    #[test]
    fn test_volcengine_protocol_endpoint() {
        let endpoint =
            VolcengineProtocol::new(KEY).chat_endpoint("https://ark.cn-beijing.volces.com");
        assert_eq!(
            endpoint,
            "https://ark.cn-beijing.volces.com/api/v3/chat/completions"
        );
    }
}