use crate::core::traits::provider::ProviderConfig;
use serde::{Deserialize, Serialize};
/// Template for the regional OCI Generative AI inference endpoint.
///
/// The literal `{region}` placeholder is replaced with a region identifier
/// (for example `us-chicago-1`) to produce the API base URL.
pub const DEFAULT_API_BASE: &str = "https://inference.generativeai.{region}.oci.oraclecloud.com";
/// Configuration for the OCI Generative AI provider.
///
/// Credential and location fields are optional. When a field is unset, validation and the
/// `get_*` helpers fall back to the matching `OCI_*` environment variable.
///
/// `Debug` is implemented by hand (below) so that secrets (`auth_token`, `private_key`)
/// are never printed when the config is logged with `{:?}`.
#[derive(Clone, Serialize, Deserialize)]
pub struct OciConfig {
    /// Auth token; falls back to `OCI_AUTH_TOKEN` when unset. Redacted in `Debug` output.
    pub auth_token: Option<String>,
    /// Compartment ID; falls back to `OCI_COMPARTMENT_ID` when unset.
    pub compartment_id: Option<String>,
    /// Region identifier substituted into `DEFAULT_API_BASE`; falls back to `OCI_REGION`.
    pub region: Option<String>,
    /// Explicit API base URL; takes precedence over `region`. Falls back to `OCI_API_BASE`.
    pub api_base: Option<String>,
    /// User OCID for API-key authentication; falls back to `OCI_USER_OCID` in validation.
    pub user_ocid: Option<String>,
    /// Tenancy OCID.
    /// NOTE(review): not consulted by validation in this file; confirm where it is used.
    pub tenancy_ocid: Option<String>,
    /// API signing key fingerprint; falls back to `OCI_FINGERPRINT` in validation.
    pub fingerprint: Option<String>,
    /// Path to the API signing private key; falls back to `OCI_PRIVATE_KEY_PATH` in validation.
    pub private_key_path: Option<String>,
    /// Inline API signing private key; falls back to `OCI_PRIVATE_KEY` in validation.
    /// Redacted in `Debug` output.
    pub private_key: Option<String>,
    /// Request timeout in seconds (defaults to 60).
    #[serde(default = "default_timeout")]
    pub timeout: u64,
    /// Maximum number of retries (defaults to 3).
    #[serde(default = "default_max_retries")]
    pub max_retries: u32,
    /// Debug flag (defaults to `false`).
    #[serde(default)]
    pub debug: bool,
}

impl std::fmt::Debug for OciConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Reveal only whether a secret is present, never its value.
        fn redact(secret: Option<&String>) -> Option<&'static str> {
            secret.map(|_| "<redacted>")
        }

        f.debug_struct("OciConfig")
            .field("auth_token", &redact(self.auth_token.as_ref()))
            .field("compartment_id", &self.compartment_id)
            .field("region", &self.region)
            .field("api_base", &self.api_base)
            .field("user_ocid", &self.user_ocid)
            .field("tenancy_ocid", &self.tenancy_ocid)
            .field("fingerprint", &self.fingerprint)
            .field("private_key_path", &self.private_key_path)
            .field("private_key", &redact(self.private_key.as_ref()))
            .field("timeout", &self.timeout)
            .field("max_retries", &self.max_retries)
            .field("debug", &self.debug)
            .finish()
    }
}
/// Builds a config with every optional setting unset and the numeric settings taken from
/// the same helpers used by the struct's `#[serde(default = ...)]` attributes.
impl Default for OciConfig {
    fn default() -> Self {
        let (timeout, max_retries) = (default_timeout(), default_max_retries());
        Self {
            timeout,
            max_retries,
            debug: false,
            // Authentication material: nothing configured by default.
            auth_token: None,
            user_ocid: None,
            tenancy_ocid: None,
            fingerprint: None,
            private_key: None,
            private_key_path: None,
            // Location: nothing configured by default.
            compartment_id: None,
            region: None,
            api_base: None,
        }
    }
}
impl ProviderConfig for OciConfig {
    /// Checks that authentication, compartment, and endpoint settings are present and that
    /// the timeout is non-zero.
    ///
    /// Every setting may come either from its config field or from its `OCI_*` environment
    /// variable. Empty or whitespace-only values are treated as missing, so an exported but
    /// blank variable does not count as configured.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message describing the first missing or invalid setting.
    fn validate(&self) -> Result<(), String> {
        // `true` when `value` is present and not blank.
        fn non_blank(value: Option<&str>) -> bool {
            matches!(value, Some(v) if !v.trim().is_empty())
        }
        // `true` when the config field or the named environment variable holds a non-blank
        // value. `std::env::var` returns `Ok("")` for variables that are set but empty, so a
        // bare `is_ok()` check would wrongly accept them.
        fn is_set(field: Option<&str>, env_key: &str) -> bool {
            non_blank(field) || non_blank(std::env::var(env_key).ok().as_deref())
        }

        let has_token = is_set(self.auth_token.as_deref(), "OCI_AUTH_TOKEN");
        let has_api_key = is_set(self.user_ocid.as_deref(), "OCI_USER_OCID")
            && is_set(self.fingerprint.as_deref(), "OCI_FINGERPRINT")
            && (is_set(self.private_key.as_deref(), "OCI_PRIVATE_KEY")
                || is_set(self.private_key_path.as_deref(), "OCI_PRIVATE_KEY_PATH"));
        if !has_token && !has_api_key {
            return Err(
                "OCI authentication not configured. Set OCI_AUTH_TOKEN or configure API key \
                 authentication with OCI_USER_OCID, OCI_FINGERPRINT, and OCI_PRIVATE_KEY \
                 (or OCI_PRIVATE_KEY_PATH)."
                    .to_string(),
            );
        }

        if !is_set(self.compartment_id.as_deref(), "OCI_COMPARTMENT_ID") {
            return Err(
                "OCI compartment ID not configured. Set OCI_COMPARTMENT_ID environment variable \
                 or pass compartment_id in configuration."
                    .to_string(),
            );
        }

        let has_region = is_set(self.region.as_deref(), "OCI_REGION");
        let has_api_base = is_set(self.api_base.as_deref(), "OCI_API_BASE");
        if !has_region && !has_api_base {
            return Err(
                "OCI region or API base URL not configured. Set OCI_REGION or OCI_API_BASE \
                 environment variable."
                    .to_string(),
            );
        }

        if self.timeout == 0 {
            return Err("Timeout must be greater than 0".to_string());
        }
        Ok(())
    }

    /// Returns the auth token set in the config.
    ///
    /// This does not consult `OCI_AUTH_TOKEN`; use `OciConfig::get_auth_token` for the
    /// env-aware lookup.
    fn api_key(&self) -> Option<&str> {
        self.auth_token.as_deref()
    }

    /// Returns the API base URL set in the config, without env or region fallback.
    fn api_base(&self) -> Option<&str> {
        self.api_base.as_deref()
    }

    /// Returns the configured timeout, interpreted as whole seconds.
    fn timeout(&self) -> std::time::Duration {
        std::time::Duration::from_secs(self.timeout)
    }

    /// Returns the configured maximum retry count.
    fn max_retries(&self) -> u32 {
        self.max_retries
    }
}
impl OciConfig {
    /// Returns the auth token from the config, falling back to `OCI_AUTH_TOKEN`.
    pub fn get_auth_token(&self) -> Option<String> {
        self.auth_token
            .clone()
            .or_else(|| std::env::var("OCI_AUTH_TOKEN").ok())
    }

    /// Returns the compartment ID from the config, falling back to `OCI_COMPARTMENT_ID`.
    pub fn get_compartment_id(&self) -> Option<String> {
        self.compartment_id
            .clone()
            .or_else(|| std::env::var("OCI_COMPARTMENT_ID").ok())
    }

    /// Returns the region from the config, falling back to `OCI_REGION`.
    pub fn get_region(&self) -> Option<String> {
        self.region
            .clone()
            .or_else(|| std::env::var("OCI_REGION").ok())
    }

    /// Resolves the API base URL.
    ///
    /// Precedence:
    /// 1. the `api_base` field;
    /// 2. the `OCI_API_BASE` environment variable;
    /// 3. the regional endpoint built from [`Self::get_region`] and [`DEFAULT_API_BASE`].
    ///
    /// Returns `None` when none of these are available.
    pub fn get_api_base(&self) -> Option<String> {
        self.api_base
            .clone()
            .or_else(|| std::env::var("OCI_API_BASE").ok())
            .or_else(|| {
                self.get_region()
                    .map(|region| Self::regional_api_base(&region))
            })
    }

    /// Builds the chat endpoint URL from [`Self::get_api_base`].
    ///
    /// Falls back to the `us-chicago-1` regional endpoint when no API base or region is
    /// configured.
    pub fn build_chat_url(&self) -> String {
        // Region used when neither an API base nor a region is configured.
        const FALLBACK_REGION: &str = "us-chicago-1";
        let base = self
            .get_api_base()
            .unwrap_or_else(|| Self::regional_api_base(FALLBACK_REGION));
        // Trim so a configured base ending in '/' does not produce a double slash.
        format!("{}/20231130/actions/chat", base.trim_end_matches('/'))
    }

    /// Returns the request URL for a model.
    ///
    /// The model name is not part of the URL; this always equals [`Self::build_chat_url`].
    /// NOTE(review): presumably the model is selected in the request body instead; confirm
    /// with the caller.
    pub fn build_model_url(&self, _model: &str) -> String {
        self.build_chat_url()
    }

    /// Expands the [`DEFAULT_API_BASE`] template for `region`.
    fn regional_api_base(region: &str) -> String {
        DEFAULT_API_BASE.replace("{region}", region)
    }
}
/// Default request timeout, in seconds, used by serde and `OciConfig::default`.
fn default_timeout() -> u64 {
    const SECONDS: u64 = 60;
    SECONDS
}
/// Default maximum retry count, used by serde and `OciConfig::default`.
fn default_max_retries() -> u32 {
    const RETRIES: u32 = 3;
    RETRIES
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A config whose required settings are all set directly on the struct, so validation
    /// results do not depend on the process environment.
    fn fully_configured() -> OciConfig {
        OciConfig {
            auth_token: Some("test-token".to_string()),
            compartment_id: Some("test-compartment".to_string()),
            api_base: Some("https://test.example.com".to_string()),
            ..Default::default()
        }
    }

    #[test]
    fn test_oci_config_default() {
        let config = OciConfig::default();
        assert!(config.auth_token.is_none());
        assert!(config.compartment_id.is_none());
        assert!(config.region.is_none());
        assert_eq!(config.timeout, 60);
        assert_eq!(config.max_retries, 3);
    }

    #[test]
    fn test_oci_config_build_chat_url() {
        let config = OciConfig {
            api_base: Some(
                "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com".to_string(),
            ),
            ..Default::default()
        };
        assert_eq!(
            config.build_chat_url(),
            "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com/20231130/actions/chat"
        );
    }

    #[test]
    fn test_oci_config_build_chat_url_trims_trailing_slash() {
        let config = OciConfig {
            api_base: Some("https://test.example.com/".to_string()),
            ..Default::default()
        };
        assert_eq!(
            config.build_chat_url(),
            "https://test.example.com/20231130/actions/chat"
        );
    }

    #[test]
    fn test_oci_config_provider_config_trait() {
        let config = OciConfig {
            auth_token: Some("test-token".to_string()),
            api_base: Some("https://test.example.com".to_string()),
            timeout: 120,
            max_retries: 5,
            ..Default::default()
        };
        assert_eq!(config.api_key(), Some("test-token"));
        assert_eq!(config.api_base(), Some("https://test.example.com"));
        assert_eq!(config.timeout(), std::time::Duration::from_secs(120));
        assert_eq!(config.max_retries(), 5);
    }

    #[test]
    fn test_oci_config_validate_accepts_complete_config() {
        assert_eq!(fully_configured().validate(), Ok(()));
    }

    #[test]
    fn test_oci_config_validate_rejects_zero_timeout() {
        let config = OciConfig {
            timeout: 0,
            ..fully_configured()
        };
        assert_eq!(
            config.validate(),
            Err("Timeout must be greater than 0".to_string())
        );
    }
}