use std::sync::Once;
use genies_core::jwt::*;
use serde::{Deserialize, Serialize};
use genies_config::app_config::ApplicationConfig;
use genies_cache::cache_service::CacheService;
use rbatis::RBatis;
/// Serializable holder for an access token string.
///
/// `RemoteToken::default()` yields an empty token without any I/O, whereas
/// `RemoteToken::new()` tries to fetch one through `get_temp_access_token`
/// using the `keycloak_*` settings from `./application.yml`.
#[derive(Serialize, Deserialize, Clone, Debug, Default)]
pub struct RemoteToken {
/// The token text; left empty when `new()` could not obtain a token.
pub access_token: String,
}
impl RemoteToken {
    /// Loads `./application.yml` and requests a temporary access token with
    /// the `keycloak_*` settings found there.
    ///
    /// The request runs on a dedicated thread that owns a fresh Tokio runtime,
    /// so this constructor blocks until the request has finished. A failed
    /// request is logged and results in an empty `access_token`.
    ///
    /// # Panics
    ///
    /// Panics if the configuration cannot be loaded, if the Tokio runtime
    /// cannot be created, or if the helper thread itself panics.
    pub fn new() -> Self {
        let config = ApplicationConfig::from_sources("./application.yml").unwrap();
        let (url, realm, resource, secret) = (
            config.keycloak_auth_server_url.clone(),
            config.keycloak_realm.clone(),
            config.keycloak_resource.clone(),
            config.keycloak_credentials_secret.clone(),
        );

        // The closure owns its inputs so it can run on a separate thread.
        let fetch_token = move || {
            let runtime = tokio::runtime::Runtime::new().unwrap();
            runtime.block_on(async {
                match get_temp_access_token(&url, &realm, &resource, &secret).await {
                    Ok(token) => token,
                    Err(e) => {
                        log::error!("Failed to get temp access token: {}", e);
                        String::new()
                    }
                }
            })
        };

        let access_token = std::thread::spawn(fetch_token).join().unwrap();
        Self { access_token }
    }
}
/// Process-wide bundle of configuration and shared service handles.
///
/// Built by `ApplicationContext::new()`; the database pool is only set up
/// once `init_database()` is called.
pub struct ApplicationContext {
/// Settings loaded from `./application.yml` by `new()`.
pub config: ApplicationConfig,
/// rbatis handle; created empty in `new()` and initialized by `init_database()`.
pub rbatis: RBatis,
/// Cache service built with `CacheService::new(&config)`.
pub cache_service: CacheService,
/// Cache service built with `CacheService::new_saved(&config)`.
pub redis_save_service: CacheService,
/// Keys fetched via `get_keycloak_keys`; empty when `auth_mode` is "local",
/// when the auth URL/realm is empty, or when the fetch failed.
pub keycloak_keys: Keys,
// Guards the one-time rbatis initialization performed in `init_database()`.
db_init_once: Once, }
async fn try_register_slot(cache: &CacheService, server_name: &str) -> Option<(i32, String)> {
for i in 0..1024i32 {
let key = format!("snowflake:slot:{}:{}", server_name, i);
match cache.set_string_ex_nx(&key, "1", Some(std::time::Duration::from_secs(3600))).await {
Ok(true) => return Some((i, key)),
_ => continue,
}
}
None
}
/// Determines the snowflake worker (machine) id for this process.
///
/// Sources are tried in this order:
/// 1. If `config.cache_type` is `"redis"`, a slot leased via
///    `try_register_slot`. On success a background thread re-writes the slot
///    key every 30 minutes (TTL one hour) through `crate::CONTEXT`.
/// 2. The numeric suffix after the last `-` in the `HOSTNAME` environment
///    variable, reduced modulo 1024.
/// 3. `config.machine_id`, if set.
/// 4. The constant `1`.
///
/// This is a synchronous function that may be called from inside or outside
/// a Tokio runtime. `block_in_place` is only legal on a multi-thread runtime,
/// so on a current-thread runtime the registration runs on a scoped helper
/// thread with its own runtime instead of panicking.
///
/// # Panics
///
/// Panics if a Tokio runtime cannot be created or a helper thread panics.
fn resolve_worker_id(config: &ApplicationConfig, cache: &CacheService) -> i32 {
    if config.cache_type == "redis" {
        let server_name: &str = &config.server_name;
        let registration = match tokio::runtime::Handle::try_current() {
            // `block_in_place` panics on a current-thread runtime, and calling
            // `block_on` directly from a runtime thread panics as well, so the
            // work is moved to a scoped thread that can borrow `cache`.
            Ok(handle)
                if matches!(
                    handle.runtime_flavor(),
                    tokio::runtime::RuntimeFlavor::CurrentThread
                ) =>
            {
                std::thread::scope(|scope| {
                    scope
                        .spawn(|| {
                            let rt = tokio::runtime::Runtime::new().unwrap();
                            rt.block_on(try_register_slot(cache, server_name))
                        })
                        .join()
                        .unwrap()
                })
            }
            Ok(handle) => tokio::task::block_in_place(|| {
                handle.block_on(try_register_slot(cache, server_name))
            }),
            Err(_) => {
                let rt = tokio::runtime::Runtime::new().unwrap();
                rt.block_on(try_register_slot(cache, server_name))
            }
        };

        if let Some((id, key)) = registration {
            log::info!("Registered snowflake worker_id via Redis slot: {}", id);
            // Detached renewal thread: it sleeps first, so `crate::CONTEXT` is
            // not touched until 30 minutes after registration.
            std::thread::spawn(move || {
                let rt = tokio::runtime::Runtime::new().unwrap();
                rt.block_on(async move {
                    loop {
                        tokio::time::sleep(std::time::Duration::from_secs(30 * 60)).await;
                        match crate::CONTEXT
                            .cache_service
                            .set_string_ex(&key, "1", Some(std::time::Duration::from_secs(3600)))
                            .await
                        {
                            Ok(_) => log::debug!("Renewed snowflake slot TTL: {}", key),
                            Err(e) => log::warn!("Failed to renew snowflake slot: {}", e),
                        }
                    }
                });
            });
            return id;
        }
    }

    // Fall back to the trailing ordinal of the host name (e.g. `app-3` -> 3).
    let host_ordinal = std::env::var("HOSTNAME").ok().and_then(|hostname| {
        hostname
            .rsplit('-')
            .next()
            .and_then(|suffix| suffix.parse::<i32>().ok())
    });
    if let Some(ordinal) = host_ordinal {
        let worker_id = ordinal % 1024;
        log::info!("Using K8s pod ordinal as machine_id: {}", worker_id);
        return worker_id;
    }

    if let Some(id) = config.machine_id {
        log::info!("Using configured machine_id: {}", id);
        return id as i32;
    }

    log::warn!("Using fallback machine_id: 1");
    1
}
/// Chooses the rbdc driver that matches the scheme of `url`.
///
/// The scheme is everything before the first `://`; when the separator is
/// missing the whole string is treated as the scheme. Each driver arm is only
/// compiled in when its Cargo feature is enabled.
///
/// # Panics
///
/// Panics when the scheme is not recognized or its feature is not enabled.
fn create_db_driver(url: &str) -> Box<dyn rbdc::db::Driver> {
    let scheme = match url.find("://") {
        Some(separator) => &url[..separator],
        None => url,
    };

    match scheme {
        #[cfg(feature = "mysql")]
        "mysql" => Box::new(rbdc_mysql::driver::MysqlDriver {}),
        #[cfg(feature = "postgres")]
        "postgres" | "postgresql" => Box::new(rbdc_pg::driver::PgDriver {}),
        #[cfg(feature = "sqlite")]
        "sqlite" => Box::new(rbdc_sqlite::driver::SqliteDriver {}),
        #[cfg(feature = "mssql")]
        "mssql" | "sqlserver" => Box::new(rbdc_mssql::driver::MssqlDriver {}),
        #[cfg(feature = "oracle")]
        "oracle" => Box::new(rbdc_oracle::driver::OracleDriver {}),
        #[cfg(feature = "tdengine")]
        "taos" | "taos+ws" => Box::new(rbdc_tdengine::driver::TaosDriver {}),
        unsupported => panic!(
            "Unsupported database scheme '{}'. Check database_url or enable the corresponding feature flag.",
            unsupported
        ),
    }
}
impl ApplicationContext {
pub async fn init_database(&self) {
self.db_init_once.call_once(|| {
let driver = create_db_driver(&self.config.database_url);
log::info!("rbatis database init ({})...", self.config.database_url);
let _ = self.rbatis.init(driver, &self.config.database_url).unwrap();
let _ = self.rbatis.get_pool().unwrap().set_max_open_conns(self.config.max_connections as u64);
let _ = self.rbatis.get_pool().unwrap().set_max_idle_conns(self.config.wait_timeout as u64);
let _ = self.rbatis.get_pool().unwrap().set_conn_max_lifetime(Some(std::time::Duration::from_secs(self.config.max_lifetime)));
});
let _ = self.rbatis.get_pool().unwrap().get().await;
log::info!("rbatis database init success! pool state = {:?}", self.rbatis.get_pool().unwrap().state().await);
}
#[deprecated(note = "Use init_database() instead")]
pub async fn init_mysql(&self) {
self.init_database().await;
}
pub fn new() -> Self {
let config = ApplicationConfig::from_sources("./application.yml").unwrap();
log::debug!("config = {:?}", config);
let auth_url = config.keycloak_auth_server_url.clone();
let auth_realm = config.keycloak_realm.clone();
let cache_service = CacheService::new(&config);
let redis_save_service = CacheService::new_saved(&config);
let machine_id = resolve_worker_id(&config, &cache_service);
genies_core::id_gen::init(machine_id, 1);
ApplicationContext {
keycloak_keys: if config.auth_mode == "local" {
log::info!("auth_mode=local, skipping keycloak key initialization");
Keys { keys: vec![] }
} else {
std::thread::spawn(move || {
let rt = tokio::runtime::Runtime::new().unwrap();
rt.block_on(async {
if auth_url.is_empty() || auth_realm.is_empty() {
log::warn!(
"Keycloak auth_server_url or realm is empty, \
skipping keycloak key initialization."
);
Ok(Keys { keys: vec![] })
} else {
log::info!("Initializing keycloak keys from: {}", auth_url);
get_keycloak_keys(&auth_url, &auth_realm).await
}
})
}).join().unwrap()
.unwrap_or_else(|e| {
log::warn!("Failed to get keycloak keys: {}, token verification will use local JWT", e);
Keys { keys: vec![] }
})
},
rbatis: RBatis::new(),
cache_service,
redis_save_service,
config,
db_init_once: Once::new(),
}
}
}
impl Default for ApplicationContext {
fn default() -> Self {
Self::new()
}
}