use std::pin::Pin;
use async_trait::async_trait;
use futures::Stream;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use crate::error::{Error, Result};
use crate::provider::{Provider, ProviderConfig};
use crate::types::{
CompletionRequest, CompletionResponse, ContentBlock, ContentDelta, Role, StopReason,
StreamChunk, StreamEventType, Usage,
};
/// Default base URL for NLP Cloud's GPU-backed API endpoints.
const NLP_CLOUD_API_BASE: &str = "https://api.nlpcloud.io/v1/gpu";
/// Provider backed by the NLP Cloud chatbot API.
///
/// Holds the provider configuration plus a `reqwest::Client` whose
/// default headers (authorization, content type) are set once in `new`.
pub struct NlpCloudProvider {
/// Configuration: API key, optional base-URL override, timeout.
config: ProviderConfig,
/// HTTP client with auth/content-type headers pre-applied.
client: Client,
}
impl NlpCloudProvider {
    /// Creates a provider from the given configuration.
    ///
    /// Builds a `reqwest::Client` with the API key applied as a default
    /// `Authorization` header and `application/json` as the content type,
    /// so individual requests do not need to set them.
    ///
    /// # Errors
    ///
    /// Returns a configuration error if the API key cannot be encoded as a
    /// header value, or an error if the HTTP client fails to build.
    pub fn new(config: ProviderConfig) -> Result<Self> {
        let mut headers = reqwest::header::HeaderMap::new();
        if let Some(ref key) = config.api_key {
            headers.insert(
                reqwest::header::AUTHORIZATION,
                // NLP Cloud uses the "Token <key>" scheme, not "Bearer".
                format!("Token {}", key)
                    .parse()
                    .map_err(|_| Error::config("Invalid API key format"))?,
            );
        }
        headers.insert(
            reqwest::header::CONTENT_TYPE,
            "application/json".parse().unwrap(),
        );
        let client = Client::builder()
            .timeout(config.timeout)
            .default_headers(headers)
            .build()?;
        Ok(Self { config, client })
    }

    /// Creates a provider configured from the `NLP_CLOUD_API_KEY`
    /// environment variable.
    pub fn from_env() -> Result<Self> {
        let config = ProviderConfig::from_env("NLP_CLOUD_API_KEY");
        Self::new(config)
    }

    /// Creates a provider with the given API key and default settings.
    pub fn with_api_key(api_key: impl Into<String>) -> Result<Self> {
        let config = ProviderConfig::new(api_key);
        Self::new(config)
    }

    /// Returns the chatbot endpoint URL for `model`, honoring a custom
    /// `base_url` from the configuration when one is set.
    fn api_url(&self, model: &str) -> String {
        format!(
            "{}/{}/chatbot",
            self.config
                .base_url
                .as_deref()
                .unwrap_or(NLP_CLOUD_API_BASE),
            model
        )
    }

    /// Converts a generic completion request into NLP Cloud's chatbot
    /// shape: completed user/assistant exchanges become `history` entries,
    /// and the trailing user text (with any system text prepended) becomes
    /// `input`.
    fn convert_request(&self, request: &CompletionRequest) -> NlpCloudRequest {
        let mut history = Vec::new();
        let mut current_input = String::new();
        for msg in &request.messages {
            // Flatten the message to plain text; non-text content blocks
            // are not supported by this API and are skipped.
            let text = msg
                .content
                .iter()
                .filter_map(|block| {
                    if let ContentBlock::Text { text } = block {
                        Some(text.clone())
                    } else {
                        None
                    }
                })
                .collect::<Vec<_>>()
                .join("\n");
            match msg.role {
                Role::User => {
                    // Append rather than overwrite so pending text from an
                    // inline system message (or a consecutive user message)
                    // is not silently discarded.
                    if current_input.is_empty() {
                        current_input = text;
                    } else {
                        current_input = format!("{}\n\n{}", current_input, text);
                    }
                }
                Role::Assistant => {
                    if !current_input.is_empty() {
                        history.push(NlpCloudHistoryItem {
                            input: current_input.clone(),
                            response: text,
                        });
                        current_input.clear();
                    }
                    // NOTE(review): an assistant message with no preceding
                    // user input is dropped — confirm this is intended.
                }
                Role::System => {
                    // Prepend system guidance to the pending user input,
                    // without a dangling separator when nothing is pending.
                    if current_input.is_empty() {
                        current_input = text;
                    } else {
                        current_input = format!("{}\n\n{}", text, current_input);
                    }
                }
            }
        }
        // Top-level system prompt: only prepended at the start of a fresh
        // conversation, and skipped if it is already present in the input.
        if let Some(ref system) = request.system {
            if history.is_empty() && !current_input.contains(system) {
                current_input = format!("{}\n\n{}", system, current_input);
            }
        }
        NlpCloudRequest {
            input: current_input,
            history: if history.is_empty() {
                None
            } else {
                Some(history)
            },
            max_length: request.max_tokens,
            temperature: request.temperature,
            top_p: request.top_p,
        }
    }

    /// Converts an NLP Cloud response into the provider-agnostic shape.
    ///
    /// The API returns neither a response id nor token counts, so a fresh
    /// UUID is generated and `usage` is zeroed out.
    fn convert_response(&self, model: &str, response: NlpCloudResponse) -> CompletionResponse {
        CompletionResponse {
            id: uuid::Uuid::new_v4().to_string(),
            model: model.to_string(),
            content: vec![ContentBlock::Text {
                text: response.response,
            }],
            stop_reason: StopReason::EndTurn,
            usage: Usage {
                input_tokens: 0,
                output_tokens: 0,
                cache_creation_input_tokens: 0,
                cache_read_input_tokens: 0,
            },
        }
    }
}
#[async_trait]
impl Provider for NlpCloudProvider {
fn name(&self) -> &str {
"nlp-cloud"
}
async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
let model = request.model.clone();
let api_request = self.convert_request(&request);
let response = self
.client
.post(self.api_url(&model))
.json(&api_request)
.send()
.await?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
return Err(Error::server(
status.as_u16(),
format!("NLP Cloud API error {}: {}", status, error_text),
));
}
let api_response: NlpCloudResponse = response.json().await?;
Ok(self.convert_response(&model, api_response))
}
async fn complete_stream(
&self,
request: CompletionRequest,
) -> Result<Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send>>> {
let response = self.complete(request).await?;
let stream = async_stream::try_stream! {
yield StreamChunk {
event_type: StreamEventType::ContentBlockStart,
index: Some(0),
delta: None,
stop_reason: None,
usage: None,
};
for block in response.content {
if let ContentBlock::Text { text } = block {
yield StreamChunk {
event_type: StreamEventType::ContentBlockDelta,
index: Some(0),
delta: Some(ContentDelta::Text { text }),
stop_reason: None,
usage: None,
};
}
}
yield StreamChunk {
event_type: StreamEventType::MessageStop,
index: None,
delta: None,
stop_reason: Some(StopReason::EndTurn),
usage: Some(response.usage),
};
};
Ok(Box::pin(stream))
}
}
/// JSON payload for NLP Cloud's `/chatbot` endpoint.
#[derive(Debug, Serialize)]
struct NlpCloudRequest {
/// Current user input (with any system text prepended by `convert_request`).
input: String,
/// Prior user/assistant exchanges; omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
history: Option<Vec<NlpCloudHistoryItem>>,
/// Maximum generation length — mapped from the request's `max_tokens`.
#[serde(skip_serializing_if = "Option::is_none")]
max_length: Option<u32>,
/// Sampling temperature; omitted when unset.
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f32>,
/// Nucleus-sampling probability mass; omitted when unset.
#[serde(skip_serializing_if = "Option::is_none")]
top_p: Option<f32>,
}
/// One completed user/assistant exchange in the conversation history.
#[derive(Debug, Serialize)]
struct NlpCloudHistoryItem {
/// The user's message in this exchange.
input: String,
/// The assistant's reply to that message.
response: String,
}
/// JSON body returned by NLP Cloud's `/chatbot` endpoint.
///
/// Only the generated text is deserialized; the API provides no response
/// id or token-usage fields (see `convert_response`).
#[derive(Debug, Deserialize)]
struct NlpCloudResponse {
/// The generated assistant reply.
response: String,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Message;

    /// Builds a provider with a dummy key for the tests below.
    fn test_provider() -> NlpCloudProvider {
        NlpCloudProvider::new(ProviderConfig::new("test-key")).unwrap()
    }

    #[test]
    fn test_provider_creation() {
        assert_eq!(test_provider().name(), "nlp-cloud");
    }

    #[test]
    fn test_provider_with_api_key() {
        let provider = NlpCloudProvider::with_api_key("test-key").unwrap();
        assert_eq!(provider.name(), "nlp-cloud");
    }

    #[test]
    fn test_api_url() {
        assert_eq!(
            test_provider().api_url("chatdolphin"),
            "https://api.nlpcloud.io/v1/gpu/chatdolphin/chatbot"
        );
    }

    #[test]
    fn test_api_url_custom_base() {
        let mut config = ProviderConfig::new("test-key");
        config.base_url = Some("https://custom.nlpcloud.io".to_string());
        let provider = NlpCloudProvider::new(config).unwrap();
        assert_eq!(
            provider.api_url("chatdolphin"),
            "https://custom.nlpcloud.io/chatdolphin/chatbot"
        );
    }

    #[test]
    fn test_convert_request() {
        let user_message = Message {
            role: Role::User,
            content: vec![ContentBlock::Text {
                text: "Hello".to_string(),
            }],
        };
        let mut request = CompletionRequest::new("chatdolphin", vec![user_message]);
        request.system = Some("You are helpful".to_string());
        request.max_tokens = Some(100);

        let api_request = test_provider().convert_request(&request);

        // The system prompt is folded into the input alongside the user text.
        assert!(api_request.input.contains("You are helpful"));
        assert!(api_request.input.contains("Hello"));
        assert_eq!(api_request.max_length, Some(100));
    }

    #[test]
    fn test_convert_request_with_history() {
        let messages = vec![
            Message::user("Hello"),
            Message::assistant("Hi there!"),
            Message::user("How are you?"),
        ];
        let request = CompletionRequest::new("chatdolphin", messages);

        let api_request = test_provider().convert_request(&request);

        // The completed exchange lands in history; the last user turn is the input.
        let history = api_request.history.expect("history should be present");
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].input, "Hello");
        assert_eq!(history[0].response, "Hi there!");
        assert_eq!(api_request.input, "How are you?");
    }

    #[test]
    fn test_convert_response() {
        let api_response = NlpCloudResponse {
            response: "Hello! I'm doing well.".to_string(),
        };

        let result = test_provider().convert_response("chatdolphin", api_response);

        assert_eq!(result.model, "chatdolphin");
        assert_eq!(result.content.len(), 1);
        assert!(matches!(
            &result.content[0],
            ContentBlock::Text { text } if text == "Hello! I'm doing well."
        ));
        assert!(matches!(result.stop_reason, StopReason::EndTurn));
    }

    #[test]
    fn test_request_serialization() {
        let history_item = NlpCloudHistoryItem {
            input: "Hi".to_string(),
            response: "Hello!".to_string(),
        };
        let request = NlpCloudRequest {
            input: "Hello".to_string(),
            history: Some(vec![history_item]),
            max_length: Some(100),
            temperature: Some(0.7),
            top_p: Some(0.9),
        };

        let json = serde_json::to_string(&request).unwrap();
        for needle in ["Hello", "history", "max_length"] {
            assert!(json.contains(needle));
        }
    }

    #[test]
    fn test_response_deserialization() {
        let json = r#"{"response": "This is a test response"}"#;
        let parsed: NlpCloudResponse = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.response, "This is a test response");
    }
}