use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::ops::Not;
use aws_sdk_s3::{
config::{Credentials, Region},
Client as S3Client,
};
use clap::Args;
use figment::{
providers::{Env, Format, Serialized, Toml},
value::magic::RelativePathBuf,
Figment,
};
use jsonwebtoken::{DecodingKey, EncodingKey};
use once_cell::sync::OnceCell;
use redis::{acl::Rule, AsyncCommands};
use sea_orm::DatabaseConnection;
use serde::{Deserialize, Serialize};
use time::Duration;
#[cfg(not(feature = "crossfire-channel"))]
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use tokio_util::sync::CancellationToken;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, Layer};
use crate::{
error::Error,
service::worker::{HeartbeatOp, HeartbeatQueue, TaskDispatcher, TaskDispatcherOp},
};
use super::TracingGuard;
/// Address the coordinator binds to when nothing else is configured: IPv4 loopback, port 5000.
pub const DEFAULT_COORDINATOR_ADDR: SocketAddr =
    SocketAddr::V4(std::net::SocketAddrV4::new(Ipv4Addr::LOCALHOST, 5000));
/// Fully resolved coordinator configuration.
///
/// Produced by `CoordinatorConfig::new`, which layers built-in defaults, TOML files,
/// `MITO_`-prefixed environment variables and CLI arguments. The `build_*` methods turn it
/// into runtime pieces (database/S3 pool, Redis credentials, JWT keys, tracing subscriber).
///
/// NOTE(review): the derived `Debug` prints secrets verbatim (`s3_secret_key`,
/// `admin_password`, the Redis passwords, and any credentials embedded in `db_url` /
/// `redis_url`). Avoid logging this struct with `{:?}`.
#[derive(Deserialize, Serialize, Debug)]
pub struct CoordinatorConfig {
    /// Socket address the coordinator server binds to.
    pub(crate) bind: SocketAddr,
    /// Database URL handed to `sea_orm::Database::connect`.
    pub(crate) db_url: String,
    /// Endpoint URL of the S3-compatible object store.
    pub(crate) s3_url: String,
    /// S3 access key id.
    pub(crate) s3_access_key: String,
    /// S3 secret access key.
    pub(crate) s3_secret_key: String,
    /// S3 region name; defaults to `"mitosis"` when absent.
    #[serde(default = "default_mitosis_region")]
    pub(crate) s3_region: String,
    /// Passed to the S3 client's `force_path_style`; defaults to `false` when absent.
    #[serde(default)]
    pub(crate) s3_force_path_style: bool,
    /// Bucket name for artifacts; defaults to `"mitosis-artifacts"` when absent.
    #[serde(default = "default_artifacts_bucket")]
    pub(crate) artifacts_bucket: String,
    /// Bucket name for attachments; defaults to `"mitosis-attachments"` when absent.
    #[serde(default = "default_attachments_bucket")]
    pub(crate) attachments_bucket: String,
    /// Redis URL. When `None`, `build_redis_connection_info` returns `Ok(None)`.
    pub(crate) redis_url: Option<String>,
    /// Password of the `mitosis_worker` Redis ACL user. When `None`, one is generated with
    /// `ACL GENPASS`.
    pub(crate) redis_worker_password: Option<String>,
    /// Password of the `mitosis_client` Redis ACL user. When `None`, one is generated with
    /// `ACL GENPASS`.
    pub(crate) redis_client_password: Option<String>,
    /// When `true`, the Redis ACL users are not reset or recreated.
    #[serde(default)]
    pub(crate) redis_skip_acl_rules: bool,
    /// Initial admin user name (rejected by `build_admin_user` above 255 bytes).
    pub(crate) admin_user: String,
    /// Initial admin password (rejected by `build_admin_user` above 255 bytes).
    pub(crate) admin_password: String,
    /// PEM file holding the EdDSA private key used to sign access tokens.
    pub(crate) access_token_private_path: RelativePathBuf,
    /// PEM file holding the EdDSA public key used to verify access tokens.
    pub(crate) access_token_public_path: RelativePathBuf,
    /// Lifetime of issued access tokens (humantime format in config files).
    #[serde(with = "humantime_serde")]
    pub(crate) access_token_expires_in: std::time::Duration,
    /// Timeout handed to `HeartbeatQueue::new` (humantime format in config files).
    /// NOTE(review): presumably the longest a worker may stay silent before being treated as
    /// lost — confirm in `HeartbeatQueue`.
    #[serde(with = "humantime_serde")]
    pub(crate) heartbeat_timeout: std::time::Duration,
    /// Explicit log file, only consulted when `file_log` is `true`. When it is unset or has no
    /// parent/file name, logs go to a daily-rotated file under the user cache directory.
    pub(crate) log_path: Option<RelativePathBuf>,
    /// Whether to also write logs to a file in addition to stdout.
    pub(crate) file_log: bool,
}
/// Serde default for `CoordinatorConfig::s3_region`.
fn default_mitosis_region() -> String {
    String::from("mitosis")
}
/// Serde default for `CoordinatorConfig::artifacts_bucket`.
fn default_artifacts_bucket() -> String {
    String::from("mitosis-artifacts")
}
/// Serde default for `CoordinatorConfig::attachments_bucket`.
fn default_attachments_bucket() -> String {
    String::from("mitosis-attachments")
}
// Command-line overrides for `CoordinatorConfig`.
//
// Serialized as the highest-priority figment layer in `CoordinatorConfig::new`. Unset `Option`
// fields and `false` flags are skipped during serialization, so they never override values
// from lower layers (defaults, TOML files, `MITO_*` env vars).
//
// Plain `//` comments are used on purpose: clap turns `///` doc comments into `--help` text.
//
// NOTE(review): the derived `Debug` prints passwords and secret keys verbatim.
#[derive(Args, Debug, Serialize, Default)]
#[command(rename_all = "kebab-case")]
pub struct CoordinatorConfigCli {
    // `-b/--bind`: socket address to bind.
    #[arg(short, long)]
    #[serde(skip_serializing_if = "::std::option::Option::is_none")]
    pub bind: Option<String>,
    // `--config`: TOML file to read; `CoordinatorConfig::new` falls back to `config.toml`.
    #[arg(long)]
    #[serde(skip_serializing_if = "::std::option::Option::is_none")]
    pub config: Option<String>,
    // `--db`: database URL.
    #[arg(long = "db")]
    #[serde(skip_serializing_if = "::std::option::Option::is_none")]
    pub db_url: Option<String>,
    // `--s3`: S3 endpoint URL.
    #[arg(long = "s3")]
    #[serde(skip_serializing_if = "::std::option::Option::is_none")]
    pub s3_url: Option<String>,
    #[arg(long)]
    #[serde(skip_serializing_if = "::std::option::Option::is_none")]
    pub s3_access_key: Option<String>,
    #[arg(long)]
    #[serde(skip_serializing_if = "::std::option::Option::is_none")]
    pub s3_secret_key: Option<String>,
    #[arg(long)]
    #[serde(skip_serializing_if = "::std::option::Option::is_none")]
    pub s3_region: Option<String>,
    // Only serialized when `true`; a CLI flag cannot turn a configured `true` back off.
    #[arg(long)]
    #[serde(skip_serializing_if = "<&bool>::not")]
    pub s3_force_path_style: bool,
    // `--redis`: Redis URL.
    #[arg(long = "redis")]
    #[serde(skip_serializing_if = "::std::option::Option::is_none")]
    pub redis_url: Option<String>,
    #[arg(long)]
    #[serde(skip_serializing_if = "::std::option::Option::is_none")]
    pub redis_worker_password: Option<String>,
    #[arg(long)]
    #[serde(skip_serializing_if = "::std::option::Option::is_none")]
    pub redis_client_password: Option<String>,
    // Only serialized when `true`.
    #[arg(long)]
    #[serde(skip_serializing_if = "<&bool>::not")]
    pub redis_skip_acl_rules: bool,
    #[arg(long)]
    #[serde(skip_serializing_if = "::std::option::Option::is_none")]
    pub admin_user: Option<String>,
    #[arg(long)]
    #[serde(skip_serializing_if = "::std::option::Option::is_none")]
    pub admin_password: Option<String>,
    #[arg(long)]
    #[serde(skip_serializing_if = "::std::option::Option::is_none")]
    pub access_token_private_path: Option<String>,
    #[arg(long)]
    #[serde(skip_serializing_if = "::std::option::Option::is_none")]
    pub access_token_public_path: Option<String>,
    // Kept as a string here; parsed as a humantime duration when the config is extracted.
    #[arg(long)]
    #[serde(skip_serializing_if = "::std::option::Option::is_none")]
    pub access_token_expires_in: Option<String>,
    // Kept as a string here; parsed as a humantime duration when the config is extracted.
    #[arg(long)]
    #[serde(skip_serializing_if = "::std::option::Option::is_none")]
    pub heartbeat_timeout: Option<String>,
    #[arg(long)]
    #[serde(skip_serializing_if = "::std::option::Option::is_none")]
    pub log_path: Option<String>,
    // Only serialized when `true`.
    #[arg(long)]
    #[serde(skip_serializing_if = "<&bool>::not")]
    pub file_log: bool,
    #[arg(long)]
    #[serde(skip_serializing_if = "::std::option::Option::is_none")]
    pub artifacts_bucket: Option<String>,
    #[arg(long)]
    #[serde(skip_serializing_if = "::std::option::Option::is_none")]
    pub attachments_bucket: Option<String>,
}
impl Default for CoordinatorConfig {
    /// Built-in fallback values, merged as the lowest-priority layer by
    /// `CoordinatorConfig::new`.
    ///
    /// These point at a local development setup: Postgres on localhost, an S3 endpoint on
    /// `localhost:9000`, and the coordinator on `127.0.0.1:5000`. They use well-known
    /// credentials, so real deployments must override them.
    fn default() -> Self {
        const ONE_WEEK_SECS: u64 = 7 * 24 * 60 * 60;
        const TEN_MINUTES_SECS: u64 = 10 * 60;

        Self {
            bind: DEFAULT_COORDINATOR_ADDR,
            db_url: String::from("postgres://mitosis:mitosis@localhost/mitosis"),
            s3_url: String::from("http://localhost:9000"),
            s3_access_key: String::from("mitosis_access"),
            s3_secret_key: String::from("mitosis_secret"),
            s3_region: default_mitosis_region(),
            s3_force_path_style: false,
            artifacts_bucket: default_artifacts_bucket(),
            attachments_bucket: default_attachments_bucket(),
            redis_url: None,
            redis_worker_password: None,
            redis_client_password: None,
            redis_skip_acl_rules: false,
            admin_user: String::from("mitosis_admin"),
            admin_password: String::from("mitosis_admin"),
            access_token_private_path: String::from("private.pem").into(),
            access_token_public_path: String::from("public.pem").into(),
            access_token_expires_in: std::time::Duration::from_secs(ONE_WEEK_SECS),
            heartbeat_timeout: std::time::Duration::from_secs(TEN_MINUTES_SECS),
            log_path: None,
            file_log: false,
        }
    }
}
impl CoordinatorConfig {
pub fn new(cli: &CoordinatorConfigCli) -> crate::error::Result<Self> {
let global_config = dirs::config_dir().map(|mut p| {
p.push("mitosis");
p.push("config.toml");
p
});
let mut figment = Figment::new().merge(Serialized::from(Self::default(), "coordinator"));
if let Some(global_config) = global_config {
if global_config.exists() {
figment = figment.merge(Toml::file(global_config).nested());
}
}
figment = figment
.merge(Toml::file(cli.config.as_deref().unwrap_or("config.toml")).nested())
.merge(Env::prefixed("MITO_").profile("coordinator"))
.merge(Serialized::from(cli, "coordinator"))
.select("coordinator");
Ok(figment.extract()?)
}
pub fn build_worker_task_queue(
&self,
cancel_token: CancellationToken,
#[cfg(not(feature = "crossfire-channel"))] rx: UnboundedReceiver<TaskDispatcherOp>,
#[cfg(feature = "crossfire-channel")] rx: crossfire::AsyncRx<TaskDispatcherOp>,
) -> TaskDispatcher {
TaskDispatcher::new(cancel_token, rx)
}
pub fn build_worker_heartbeat_queue(
&self,
cancel_token: CancellationToken,
pool: InfraPool,
#[cfg(not(feature = "crossfire-channel"))] rx: UnboundedReceiver<HeartbeatOp>,
#[cfg(feature = "crossfire-channel")] rx: crossfire::AsyncRx<HeartbeatOp>,
) -> HeartbeatQueue {
HeartbeatQueue::new(cancel_token, self.heartbeat_timeout, pool, rx)
}
pub async fn build_redis_connection_info(
&self,
) -> crate::error::Result<Option<RedisConnectionInfo>> {
match self.redis_url {
Some(ref redis_url) => {
let client = redis::Client::open(redis_url.clone())?;
let mut conn = client.get_multiplexed_tokio_connection().await?;
let worker_pass = {
if let Some(worker_pass) = &self.redis_worker_password {
worker_pass.clone()
} else {
conn.acl_genpass().await?
}
};
let client_pass = {
if let Some(client_pass) = &self.redis_client_password {
client_pass.clone()
} else {
conn.acl_genpass().await?
}
};
if !self.redis_skip_acl_rules {
let rules = [Rule::Reset];
let _: String = conn.acl_setuser_rules("mitosis_worker", &rules).await?;
let _: String = conn.acl_setuser_rules("mitosis_client", &rules).await?;
let rules = [
Rule::On,
Rule::AddPass(worker_pass.clone()),
Rule::Pattern("task:*".to_string()),
Rule::Other("&task:*".to_string()),
Rule::AddCategory("read".to_string()),
Rule::AddCategory("write".to_string()),
Rule::AddCategory("connection".to_string()),
Rule::AddCategory("pubsub".to_string()),
Rule::RemoveCategory("dangerous".to_string()),
];
let _: String = conn.acl_setuser_rules("mitosis_worker", &rules).await?;
let rules = [
Rule::On,
Rule::AddPass(client_pass.clone()),
Rule::Pattern("task:*".to_string()),
Rule::Other("&task:*".to_string()),
Rule::AddCategory("read".to_string()),
Rule::AddCategory("pubsub".to_string()),
Rule::AddCategory("connection".to_string()),
Rule::RemoveCategory("dangerous".to_string()),
];
let _: String = conn.acl_setuser_rules("mitosis_client", &rules).await?;
}
let conn_info = client.get_connection_info();
Ok(Some(RedisConnectionInfo::new(
conn_info.addr.clone(),
worker_pass,
client_pass,
)))
}
None => Ok(None),
}
}
pub async fn build_infra_pool(
&self,
#[cfg(not(feature = "crossfire-channel"))] worker_task_queue_tx: UnboundedSender<
TaskDispatcherOp,
>,
#[cfg(feature = "crossfire-channel")] worker_task_queue_tx: crossfire::MTx<
TaskDispatcherOp,
>,
#[cfg(not(feature = "crossfire-channel"))] worker_heartbeat_queue_tx: UnboundedSender<
HeartbeatOp,
>,
#[cfg(feature = "crossfire-channel")] worker_heartbeat_queue_tx: crossfire::MTx<
HeartbeatOp,
>,
) -> crate::error::Result<InfraPool> {
let db = sea_orm::Database::connect(&self.db_url).await?;
let credential = Credentials::new(
&self.s3_access_key,
&self.s3_secret_key,
None,
None,
"mitosis",
);
let config: aws_sdk_s3::Config = aws_sdk_s3::Config::builder()
.credentials_provider(credential)
.endpoint_url(self.s3_url.clone())
.region(Region::new(self.s3_region.clone()))
.force_path_style(self.s3_force_path_style)
.build();
let s3 = S3Client::from_conf(config);
Ok(InfraPool {
db,
s3,
artifacts_bucket: self.artifacts_bucket.clone(),
attachments_bucket: self.attachments_bucket.clone(),
worker_task_queue_tx,
worker_heartbeat_queue_tx,
})
}
pub fn build_admin_user(&self) -> crate::error::Result<InitAdminUser> {
if self.admin_password.len() > 255 || self.admin_user.len() > 255 {
Err(figment::Error::from("username or password too long").into())
} else {
Ok(InitAdminUser {
username: self.admin_user.clone(),
password: self.admin_password.clone(),
})
}
}
pub fn build_server_config(&self) -> crate::error::Result<ServerConfig> {
Ok(ServerConfig {
bind: self.bind,
token_expires_in: Duration::try_from(self.access_token_expires_in)
.map_err(|e| figment::Error::from(e.to_string()))?,
})
}
pub async fn build_jwt_encoding_key(&self) -> crate::error::Result<EncodingKey> {
let private_key = tokio::fs::read(&self.access_token_private_path.relative()).await?;
Ok(EncodingKey::from_ed_pem(&private_key)?)
}
pub async fn build_jwt_decoding_key(&self) -> crate::error::Result<DecodingKey> {
let public_key = tokio::fs::read(&self.access_token_public_path.relative()).await?;
Ok(DecodingKey::from_ed_pem(&public_key)?)
}
pub fn setup_tracing_subscriber(&self) -> crate::error::Result<TracingGuard> {
if self.file_log {
let file_logger = self
.log_path
.as_ref()
.and_then(|p| {
let path = p.relative();
let dir = path.parent();
let file_name = path.file_name();
match (dir, file_name) {
(Some(dir), Some(file_name)) => {
Some(tracing_appender::rolling::never(dir, file_name))
}
_ => None,
}
})
.or_else(|| {
dirs::cache_dir()
.map(|mut p| {
p.push("mitosis");
p.push("coordinator");
p
})
.map(|dir| {
tracing_appender::rolling::daily(dir, format!("{}.log", self.bind))
})
})
.ok_or(Error::ConfigError(Box::new(figment::Error::from(
"log path not valid and cache directory not found",
))))?;
let (non_blocking, guard) = tracing_appender::non_blocking(file_logger);
let env_filter = tracing_subscriber::EnvFilter::try_from_env("MITO_FILE_LOG_LEVEL")
.unwrap_or_else(|_| "netmito=info".into());
let coordinator_guard = tracing_subscriber::registry()
.with(
tracing_subscriber::fmt::layer().with_filter(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| "netmito=info".into()),
),
)
.with(
tracing_subscriber::fmt::layer()
.with_writer(non_blocking)
.with_filter(env_filter),
)
.set_default();
Ok(TracingGuard {
subscriber_guard: Some(coordinator_guard),
file_guard: Some(guard),
})
} else {
let coordinator_guard = tracing_subscriber::registry()
.with(
tracing_subscriber::fmt::layer().with_filter(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| "netmito=info".into()),
),
)
.set_default();
Ok(TracingGuard {
subscriber_guard: Some(coordinator_guard),
file_guard: None,
})
}
}
}
/// Shared infrastructure handles built by `CoordinatorConfig::build_infra_pool`.
///
/// The sender types depend on the `crossfire-channel` feature.
/// NOTE(review): the senders are presumably paired with the receivers handed to
/// `build_worker_task_queue` / `build_worker_heartbeat_queue` — confirm at the call site.
#[derive(Debug, Clone)]
pub struct InfraPool {
    /// Database connection opened from `db_url`.
    pub db: DatabaseConnection,
    /// S3 client configured from the `s3_*` settings.
    pub s3: S3Client,
    /// Name of the artifacts bucket.
    pub artifacts_bucket: String,
    /// Name of the attachments bucket.
    pub attachments_bucket: String,
    /// Sender for `TaskDispatcherOp` messages (tokio unbounded channel).
    #[cfg(not(feature = "crossfire-channel"))]
    pub worker_task_queue_tx: UnboundedSender<TaskDispatcherOp>,
    /// Sender for `TaskDispatcherOp` messages (crossfire channel).
    #[cfg(feature = "crossfire-channel")]
    pub worker_task_queue_tx: crossfire::MTx<TaskDispatcherOp>,
    /// Sender for `HeartbeatOp` messages (tokio unbounded channel).
    #[cfg(not(feature = "crossfire-channel"))]
    pub worker_heartbeat_queue_tx: UnboundedSender<HeartbeatOp>,
    /// Sender for `HeartbeatOp` messages (crossfire channel).
    #[cfg(feature = "crossfire-channel")]
    pub worker_heartbeat_queue_tx: crossfire::MTx<HeartbeatOp>,
}
/// Server settings produced by `CoordinatorConfig::build_server_config`.
#[derive(Debug)]
pub struct ServerConfig {
    /// Socket address the server binds to.
    pub bind: SocketAddr,
    /// Lifetime of issued access tokens, converted to `time::Duration`.
    pub token_expires_in: Duration,
}
/// Credentials for the initial administrator account, as returned by
/// `CoordinatorConfig::build_admin_user`.
///
/// `Debug` is implemented by hand so the password is never printed (for example when the
/// value is logged with `{:?}`).
pub struct InitAdminUser {
    /// Admin user name.
    pub username: String,
    /// Admin password in the form it was configured.
    pub password: String,
}

impl std::fmt::Debug for InitAdminUser {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("InitAdminUser")
            .field("username", &self.username)
            .field("password", &format_args!("<redacted>"))
            .finish()
    }
}
/// Redis address plus the passwords of the `mitosis_worker` and `mitosis_client` ACL users.
///
/// `Debug` is implemented by hand so the passwords are never printed.
pub struct RedisConnectionInfo {
    addr: redis::ConnectionAddr,
    worker_pass: String,
    client_pass: String,
}

impl std::fmt::Debug for RedisConnectionInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RedisConnectionInfo")
            .field("addr", &self.addr)
            .field("worker_pass", &format_args!("<redacted>"))
            .field("client_pass", &format_args!("<redacted>"))
            .finish()
    }
}
impl RedisConnectionInfo {
    /// Bundles the Redis address with the two ACL users' passwords.
    pub fn new(addr: redis::ConnectionAddr, worker_pass: String, client_pass: String) -> Self {
        Self {
            addr,
            worker_pass,
            client_pass,
        }
    }

    /// Connection URL for the `mitosis_worker` ACL user, requesting the RESP3 protocol.
    pub fn worker_url(&self) -> String {
        self.user_url("mitosis_worker", &self.worker_pass)
    }

    /// Connection URL for the `mitosis_client` ACL user, requesting the RESP3 protocol.
    pub fn client_url(&self) -> String {
        self.user_url("mitosis_client", &self.client_pass)
    }

    /// Formats `redis://<user>:<password>@<host:port>/?protocol=resp3`.
    ///
    /// The password is percent-encoded. Without this, a configured password containing
    /// `/`, `?`, `#` or `%` would truncate or corrupt the URL (generated `ACL GENPASS`
    /// passwords are unaffected).
    /// NOTE(review): relies on the `redis` crate percent-decoding URL credentials when
    /// parsing; confirm for the crate version in use.
    /// NOTE(review): `TcpTls` addresses still get the plain `redis://` scheme, and Unix-socket
    /// addresses do not form a valid URL here. Confirm those setups are unsupported.
    fn user_url(&self, user: &str, pass: &str) -> String {
        format!(
            "redis://{}:{}@{}/?protocol=resp3",
            user,
            Self::percent_encode_userinfo(pass),
            self.host_port()
        )
    }

    /// `host:port` for the URL authority, with IPv6 literals wrapped in brackets.
    ///
    /// `ConnectionAddr`'s `Display` writes `host:port` verbatim. An unbracketed IPv6 host
    /// (e.g. `::1:6379`) cannot be parsed back as a URL, so those hosts are bracketed here.
    fn host_port(&self) -> String {
        match &self.addr {
            redis::ConnectionAddr::Tcp(host, port)
            | redis::ConnectionAddr::TcpTls { host, port, .. }
                if host.parse::<std::net::Ipv6Addr>().is_ok() =>
            {
                format!("[{host}]:{port}")
            }
            other => other.to_string(),
        }
    }

    /// Percent-encodes every byte except RFC 3986 unreserved characters
    /// (`A-Z a-z 0-9 - . _ ~`), making `raw` safe inside URL userinfo.
    fn percent_encode_userinfo(raw: &str) -> String {
        use std::fmt::Write as _;

        let mut encoded = String::with_capacity(raw.len());
        for byte in raw.bytes() {
            if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
                encoded.push(char::from(byte));
            } else {
                // Writing into a `String` cannot fail.
                let _ = write!(encoded, "%{byte:02X}");
            }
        }
        encoded
    }
}
/// Server settings, set at most once.
/// NOTE(review): presumably filled from `CoordinatorConfig::build_server_config` at startup.
pub(crate) static SERVER_CONFIG: OnceCell<ServerConfig> = OnceCell::new();
/// Initial admin credentials, set at most once.
/// NOTE(review): presumably filled from `CoordinatorConfig::build_admin_user`.
pub(crate) static INIT_ADMIN_USER: OnceCell<InitAdminUser> = OnceCell::new();
/// JWT signing key, set at most once.
/// NOTE(review): presumably filled from `CoordinatorConfig::build_jwt_encoding_key`.
pub(crate) static ENCODING_KEY: OnceCell<EncodingKey> = OnceCell::new();
/// JWT verification key, set at most once.
/// NOTE(review): presumably filled from `CoordinatorConfig::build_jwt_decoding_key`.
pub(crate) static DECODING_KEY: OnceCell<DecodingKey> = OnceCell::new();
/// Redis credentials for workers/clients; stays empty when Redis is not configured.
/// NOTE(review): presumably filled from `CoordinatorConfig::build_redis_connection_info`.
pub(crate) static REDIS_CONNECTION_INFO: OnceCell<RedisConnectionInfo> = OnceCell::new();
/// Secret string stored once. Its purpose (presumably authorizing a shutdown request) is not
/// visible here — confirm at the use site.
pub(crate) static SHUTDOWN_SECRET: OnceCell<String> = OnceCell::new();