use serde::{Deserialize, Serialize};
use crate::error::{Error, Result};
pub struct SAPGenerativeAIProvider {
#[allow(dead_code)]
api_key: String,
#[allow(dead_code)]
tenant_id: String,
}
impl SAPGenerativeAIProvider {
pub fn new(api_key: &str, tenant_id: &str) -> Self {
Self {
api_key: api_key.to_string(),
tenant_id: tenant_id.to_string(),
}
}
pub fn from_env() -> Result<Self> {
let api_key = std::env::var("SAP_API_KEY").map_err(|_| {
Error::Configuration("SAP_API_KEY environment variable not set".to_string())
})?;
let tenant_id = std::env::var("SAP_TENANT_ID").map_err(|_| {
Error::Configuration("SAP_TENANT_ID environment variable not set".to_string())
})?;
Ok(Self::new(&api_key, &tenant_id))
}
pub async fn list_models(&self) -> Result<Vec<String>> {
Ok(vec![
"gpt-4".to_string(),
"gpt-3.5-turbo".to_string(),
"llama-2-7b".to_string(),
"llama-2-13b".to_string(),
"llama-2-70b".to_string(),
"gemini-pro".to_string(),
"claude-2".to_string(),
"ai21-j2-jumbo".to_string(),
])
}
pub fn get_model_info(model: &str) -> Option<SAPModelInfo> {
match model {
"gpt-4" => Some(SAPModelInfo {
name: "gpt-4".to_string(),
provider: "OpenAI".to_string(),
context_window: 8192,
supports_vision: true,
supports_function_call: true,
max_output_tokens: 4096,
fine_tuning_available: false,
family: "GPT".to_string(),
}),
"gpt-3.5-turbo" => Some(SAPModelInfo {
name: "gpt-3.5-turbo".to_string(),
provider: "OpenAI".to_string(),
context_window: 4096,
supports_vision: false,
supports_function_call: true,
max_output_tokens: 2048,
fine_tuning_available: true,
family: "GPT".to_string(),
}),
"llama-2-7b" => Some(SAPModelInfo {
name: "llama-2-7b".to_string(),
provider: "Meta".to_string(),
context_window: 4096,
supports_vision: false,
supports_function_call: false,
max_output_tokens: 2048,
fine_tuning_available: true,
family: "LLaMA".to_string(),
}),
"llama-2-13b" => Some(SAPModelInfo {
name: "llama-2-13b".to_string(),
provider: "Meta".to_string(),
context_window: 4096,
supports_vision: false,
supports_function_call: false,
max_output_tokens: 4096,
fine_tuning_available: true,
family: "LLaMA".to_string(),
}),
"llama-2-70b" => Some(SAPModelInfo {
name: "llama-2-70b".to_string(),
provider: "Meta".to_string(),
context_window: 4096,
supports_vision: false,
supports_function_call: false,
max_output_tokens: 4096,
fine_tuning_available: false,
family: "LLaMA".to_string(),
}),
"gemini-pro" => Some(SAPModelInfo {
name: "gemini-pro".to_string(),
provider: "Google".to_string(),
context_window: 32768,
supports_vision: true,
supports_function_call: true,
max_output_tokens: 8192,
fine_tuning_available: false,
family: "Gemini".to_string(),
}),
"claude-2" => Some(SAPModelInfo {
name: "claude-2".to_string(),
provider: "Anthropic".to_string(),
context_window: 100000,
supports_vision: false,
supports_function_call: false,
max_output_tokens: 4096,
fine_tuning_available: false,
family: "Claude".to_string(),
}),
"ai21-j2-jumbo" => Some(SAPModelInfo {
name: "ai21-j2-jumbo".to_string(),
provider: "AI21".to_string(),
context_window: 8191,
supports_vision: false,
supports_function_call: false,
max_output_tokens: 4096,
fine_tuning_available: true,
family: "J2".to_string(),
}),
_ => None,
}
}
pub fn get_consumption_plan(plan: &str) -> Option<SAPConsumptionPlan> {
match plan {
"free" => Some(SAPConsumptionPlan {
plan_name: "Free".to_string(),
monthly_tokens: 100000,
price_per_1m_tokens: 0.0,
support_tier: "Community".to_string(),
api_calls_per_minute: 10,
concurrent_requests: 1,
}),
"starter" => Some(SAPConsumptionPlan {
plan_name: "Starter".to_string(),
monthly_tokens: 10000000,
price_per_1m_tokens: 1.5,
support_tier: "Standard".to_string(),
api_calls_per_minute: 100,
concurrent_requests: 10,
}),
"professional" => Some(SAPConsumptionPlan {
plan_name: "Professional".to_string(),
monthly_tokens: 1000000000,
price_per_1m_tokens: 0.8,
support_tier: "Priority".to_string(),
api_calls_per_minute: 1000,
concurrent_requests: 100,
}),
"enterprise" => Some(SAPConsumptionPlan {
plan_name: "Enterprise".to_string(),
monthly_tokens: 10000000000,
price_per_1m_tokens: 0.5,
support_tier: "Enterprise".to_string(),
api_calls_per_minute: 10000,
concurrent_requests: 1000,
}),
_ => None,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SAPModelInfo {
pub name: String,
pub provider: String,
pub context_window: u32,
pub supports_vision: bool,
pub supports_function_call: bool,
pub max_output_tokens: u32,
pub fine_tuning_available: bool,
pub family: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SAPConsumptionPlan {
pub plan_name: String,
pub monthly_tokens: u64,
pub price_per_1m_tokens: f64,
pub support_tier: String,
pub api_calls_per_minute: u32,
pub concurrent_requests: u32,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum IntegrationType {
#[default]
DirectAPI,
SAPBtp,
OnPremises,
Hybrid,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_sap_provider_creation() {
let provider = SAPGenerativeAIProvider::new("test-key", "test-tenant");
assert_eq!(provider.api_key, "test-key");
assert_eq!(provider.tenant_id, "test-tenant");
}
#[tokio::test]
async fn test_list_models() {
let provider = SAPGenerativeAIProvider::new("test-key", "test-tenant");
let models = provider.list_models().await.unwrap();
assert!(!models.is_empty());
assert!(models.contains(&"gpt-4".to_string()));
assert!(models.contains(&"llama-2-70b".to_string()));
assert!(models.contains(&"claude-2".to_string()));
}
#[test]
fn test_get_model_info() {
let info = SAPGenerativeAIProvider::get_model_info("gpt-4").unwrap();
assert_eq!(info.name, "gpt-4");
assert_eq!(info.provider, "OpenAI");
assert!(info.supports_vision);
assert!(info.supports_function_call);
}
#[test]
fn test_llama_model_info() {
let info = SAPGenerativeAIProvider::get_model_info("llama-2-70b").unwrap();
assert_eq!(info.family, "LLaMA");
assert_eq!(info.provider, "Meta");
assert!(!info.fine_tuning_available); }
#[test]
fn test_claude_model_info() {
let info = SAPGenerativeAIProvider::get_model_info("claude-2").unwrap();
assert_eq!(info.provider, "Anthropic");
assert_eq!(info.context_window, 100000);
assert!(!info.supports_vision);
}
#[test]
fn test_gemini_model_info() {
let info = SAPGenerativeAIProvider::get_model_info("gemini-pro").unwrap();
assert_eq!(info.provider, "Google");
assert!(info.supports_vision);
assert_eq!(info.context_window, 32768);
}
#[test]
fn test_model_info_invalid() {
let info = SAPGenerativeAIProvider::get_model_info("invalid-model");
assert!(info.is_none());
}
#[test]
fn test_get_consumption_plan() {
let plan = SAPGenerativeAIProvider::get_consumption_plan("starter").unwrap();
assert_eq!(plan.plan_name, "Starter");
assert!(plan.monthly_tokens > 0);
assert_eq!(plan.api_calls_per_minute, 100);
}
#[test]
fn test_free_plan() {
let plan = SAPGenerativeAIProvider::get_consumption_plan("free").unwrap();
assert_eq!(plan.plan_name, "Free");
assert_eq!(plan.price_per_1m_tokens, 0.0);
assert_eq!(plan.support_tier, "Community");
}
#[test]
fn test_enterprise_plan() {
let plan = SAPGenerativeAIProvider::get_consumption_plan("enterprise").unwrap();
assert_eq!(plan.plan_name, "Enterprise");
assert!(plan.monthly_tokens > 1000000000);
assert_eq!(plan.support_tier, "Enterprise");
}
#[test]
fn test_consumption_plan_invalid() {
let plan = SAPGenerativeAIProvider::get_consumption_plan("invalid-plan");
assert!(plan.is_none());
}
#[test]
fn test_plan_pricing_hierarchy() {
let free = SAPGenerativeAIProvider::get_consumption_plan("free").unwrap();
let starter = SAPGenerativeAIProvider::get_consumption_plan("starter").unwrap();
let professional = SAPGenerativeAIProvider::get_consumption_plan("professional").unwrap();
let enterprise = SAPGenerativeAIProvider::get_consumption_plan("enterprise").unwrap();
assert!(free.monthly_tokens < starter.monthly_tokens);
assert!(starter.monthly_tokens < professional.monthly_tokens);
assert!(professional.monthly_tokens < enterprise.monthly_tokens);
assert!(starter.price_per_1m_tokens > professional.price_per_1m_tokens);
}
#[test]
fn test_integration_type_default() {
assert_eq!(IntegrationType::default(), IntegrationType::DirectAPI);
}
#[test]
fn test_model_count() {
let models = vec![
"gpt-4",
"gpt-3.5-turbo",
"llama-2-7b",
"llama-2-13b",
"llama-2-70b",
"gemini-pro",
"claude-2",
"ai21-j2-jumbo",
];
assert_eq!(models.len(), 8);
for model in models {
assert!(SAPGenerativeAIProvider::get_model_info(model).is_some());
}
}
}