use std::{env::var, pin::Pin, str::FromStr, time::Duration};
use async_openai::{config::OpenAIConfig, Client};
use sqlx::{
postgres::{PgListener, PgPoolOptions},
Pool, Postgres,
};
use tokio_stream::{Stream, StreamExt};
use uuid::Uuid;
use crate::{
errors::sdk::SDKError,
organization::operations::{
Organization, OrganizationCrudOperations, OrganizationInitializationInput, SetOrganizationInputBuilder,
GLOBAL_ORGANIZATION_SETTINGS_NAME,
},
resources::changes::change::{ChangeOperation, ChangeResourceType, ListenEvent}, };
/// Crate version captured at compile time from the `CARGO_PKG_VERSION`
/// environment variable; `None` if that variable was not set during the build.
const VERSION: Option<&str> = option_env!("CARGO_PKG_VERSION");
/// Settings needed to construct the SDK: where the database lives and how to
/// reach the LLM provider.
#[derive(Clone)]
pub struct SDKConfig {
    /// Connection URL handed to the database pool.
    pub database_url: String,
    /// API key used to configure the LLM client.
    pub llm_api_key: String,
    /// Name of the LLM model to use.
    pub llm_model_name: String,
}

impl SDKConfig {
    /// Builds a configuration from process environment variables.
    ///
    /// * `DATABASE_URL` — required.
    /// * `OPENAI_API_KEY` — required.
    /// * `OPENAI_MODEL_NAME` — optional; falls back to `gpt-3.5-turbo` when the
    ///   variable is unset or cannot be read.
    ///
    /// # Panics
    ///
    /// Panics if `DATABASE_URL` or `OPENAI_API_KEY` is missing or not valid
    /// Unicode. The panic message names the offending variable.
    pub fn from_env() -> SDKConfig {
        let database_url = var("DATABASE_URL")
            .expect("environment variable DATABASE_URL must be set to a valid Unicode string");
        let llm_api_key = var("OPENAI_API_KEY")
            .expect("environment variable OPENAI_API_KEY must be set to a valid Unicode string");
        // Lazy fallback: only allocate the default when the variable is absent.
        let llm_model_name =
            var("OPENAI_MODEL_NAME").unwrap_or_else(|_| "gpt-3.5-turbo".to_string());
        SDKConfig {
            database_url,
            llm_api_key,
            llm_model_name,
        }
    }
}
/// Bundles the configuration, the Postgres connection pool and the OpenAI
/// client that the SDK operations in this file work with.
// NOTE(review): the `Box` wrappers around the pool and client look unnecessary,
// but the fields are public so changing their types would break callers — confirm
// before touching.
#[derive(Clone)]
pub struct SDKEngine {
/// Configuration this engine was created with.
pub config: SDKConfig,
/// Postgres connection pool.
pub db_pool: Box<Pool<Postgres>>,
/// OpenAI API client.
pub llm_client: Box<Client<OpenAIConfig>>,
}
impl SDKEngine {
pub async fn new(config: SDKConfig) -> Result<SDKEngine, SDKError> {
let pool = PgPoolOptions::new()
.max_connections(10)
.acquire_timeout(Duration::from_secs(60))
.connect(config.database_url.as_str())
.await?;
let llm_config = OpenAIConfig::default().with_api_key(config.llm_api_key.clone());
let llm_client = Box::new(Client::with_config(llm_config));
let db_pool = Box::new(pool);
let engine = SDKEngine {
config,
db_pool,
llm_client,
};
Ok(engine)
}
pub async fn migrate(&self) -> Result<(), SDKError> {
sqlx::migrate!().run(self.db_pool.as_ref()).await?;
Ok(())
}
pub fn version(&self) -> Result<String, SDKError> {
match VERSION {
Some(version) => Ok(version.to_string()),
None => Err(SDKError::VersionNotFound),
}
}
pub async fn initialize_organization(
&self,
owner_id: Uuid,
value: OrganizationInitializationInput,
) -> Result<Organization, SDKError> {
let org_serialized = serde_json::to_string(&value)?;
let org = self
.set_organization_setting(
SetOrganizationInputBuilder::default()
.owner_id(owner_id)
.name(GLOBAL_ORGANIZATION_SETTINGS_NAME.to_string())
.value(org_serialized)
.build()
.unwrap(),
)
.await?;
Ok(org.into())
}
pub async fn listen(
&self,
resource: ChangeResourceType,
) -> Result<Pin<Box<dyn Stream<Item = Result<ListenEvent, SDKError>> + Send>>, SDKError> {
let mut db_listener = PgListener::connect_with(&self.db_pool).await?;
db_listener
.listen(format!("{}_table_update", resource.to_string().to_lowercase()).as_str())
.await?;
let mapped_stream = db_listener.into_stream().map(|x| match x {
Ok(not) => {
let mut payload = not.payload().split_whitespace();
let resource = ChangeResourceType::from_str(payload.next().unwrap()).unwrap();
let operation = ChangeOperation::from_str(payload.next().unwrap()).unwrap();
let row_id = payload.next().map(|a| a.parse::<Uuid>().unwrap()).unwrap();
Ok(ListenEvent {
resource,
operation,
row_id,
})
}
Err(e) => Err(SDKError::from(e)),
});
Ok(Box::pin(mapped_stream))
}
}