use async_sqlite::ClientBuilder;
use clap::Parser;
use async_sqlite::rusqlite::{params, Connection, OpenFlags, Result};
use serde::Serialize;
use shvbroker::config::{AccessRule, Mount, Password, ProfileValue, Role, User};
use std::collections::BTreeMap;
use serde::Deserialize;
use shvbroker::config::{BrokerConfig, BrokerConnectionConfig, ConnectionMountSettings, Listen, AzureConfig as BrokerAzureConfig};
use shvproto::RpcValue;
use shvrpc::client::ClientConfig;
use url::Url;
use std::path::Path;
use std::time::Duration;
/// Reads every row of the legacy `acl_users` table into a map keyed by user name.
///
/// A `passwordFormat` of exactly `SHA1` marks the stored password as a SHA1 hash; any other
/// value (including NULL) stores it as a plain-text password. The `roles` column is a comma
/// separated list: entries are trimmed, empty entries are dropped, and NULL yields no roles.
/// Every migrated user is created active (`deactivated: false`).
fn load_users(conn: &Connection) -> Result<BTreeMap<String, User>> {
    let mut query = conn.prepare("SELECT name, password, passwordFormat, roles FROM acl_users")?;
    let users = query
        .query_map([], |row| {
            let user_name: String = row.get("name")?;
            let secret: String = row.get("password")?;
            let format: Option<String> = row.get("passwordFormat")?;
            let role_csv: Option<String> = row.get("roles")?;
            let password = if format.as_deref() == Some("SHA1") {
                Password::Sha1(secret)
            } else {
                Password::Plain(secret)
            };
            let roles = role_csv
                .iter()
                .flat_map(|csv| csv.split(','))
                .map(str::trim)
                .filter(|entry| !entry.is_empty())
                .map(String::from)
                .collect();
            Ok((user_name, User { password, roles, deactivated: false }))
        })?
        .collect::<Result<BTreeMap<String, User>, _>>()?;
    Ok(users)
}
/// Reads every row of the legacy `acl_mounts` table into a map keyed by device id.
///
/// A NULL `description` becomes an empty string. The legacy schema is assumed to allow NULL
/// there (TODO confirm); before this change, a single such row failed the whole migration.
///
/// # Errors
/// Returns any SQLite error, e.g. a NULL `deviceId` or `mountPoint`.
fn load_mounts(conn: &Connection) -> Result<BTreeMap<String, Mount>> {
    let mut stmt = conn.prepare("SELECT deviceId, mountPoint, description FROM acl_mounts")?;
    let mut rows = stmt.query([])?;
    let mut mounts = BTreeMap::new();
    while let Some(row) = rows.next()? {
        let device_id: String = row.get("deviceId")?;
        let mount_point: String = row.get("mountPoint")?;
        // Reading straight into `String` fails with InvalidColumnType on NULL.
        let description: String = row
            .get::<_, Option<String>>("description")?
            .unwrap_or_default();
        mounts.insert(device_id, Mount { mount_point, description });
    }
    Ok(mounts)
}
/// Removes one leading `azure:` prefix from a legacy role name; other names are returned as is.
fn fix_azure_role_prefix(name: String) -> String {
    if let Some(stripped) = name.strip_prefix("azure:") {
        return stripped.to_string();
    }
    name
}
fn load_roles(conn: &Connection) -> Result<BTreeMap<String, Role>> {
let mut stmt = conn.prepare("SELECT name, roles, profile FROM acl_roles")?;
let mut roles = stmt.query_map([], |row| {
let name: String = fix_azure_role_prefix(row.get(0)?);
let roles_str: String = row.get(1)?;
let profile_str: Option<String> = row.get(2).ok();
let role_list: Vec<String> = roles_str
.split(',')
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
.collect();
let profile = if let Some(s) = profile_str {
if !s.trim().is_empty() {
match serde_json::from_str::<ProfileValue>(&s) {
Ok(p) => Some(p),
Err(e) => {
eprintln!("Failed to parse profile JSON for {name}: {e}");
Some(ProfileValue::Null)
}
}
} else {
None
}
} else {
None
};
Ok((name, Role {
roles: role_list,
access: vec![],
profile,
}))
})?
.collect::<Result<BTreeMap<String, Role>, _>>()?;
let mut stmt = conn.prepare(
"SELECT role, path, method, accessRole, ruleNumber
FROM acl_access
ORDER BY role, ruleNumber ASC",
)?;
let access_rows = stmt.query_map([], |row| {
let role: String = fix_azure_role_prefix(row.get(0)?);
let path = row.get(1).map(|s: Option<String>| s.unwrap_or_default().trim().to_string())?;
let method = row.get(2).map(|s: Option<String>| s.unwrap_or_default().trim().to_string())?;
let grant: String = row.get(3)?;
let shv_ri = format!("{}:{}", if path.is_empty() { "**" } else { &path }, if method.is_empty() { "*" } else { &method });
let access_rule = AccessRule { shv_ri, grant };
if let Err(err) = shvbroker::brokerimpl::ParsedAccessRule::try_from(&access_rule) {
panic!("Cannot parse AccessRule from acl_access table, row: {row:?} error: {err}");
}
Ok((role, access_rule))
})?;
for row in access_rows {
let (role_name, access_rule) = row?;
if let Some(role) = roles.get_mut(&role_name) {
role.access.push(access_rule);
} else {
eprintln!("Warning: acl_access refers to undefined role '{role_name}'");
}
}
Ok(roles)
}
/// Writes every `(key, value)` of `map` into `table` as `(id, def)` rows in one transaction.
///
/// Each value is serialized to JSON. Existing rows with the same `id` are replaced. If any
/// insert or serialization fails, the transaction is dropped without commit, so nothing from
/// this call is stored.
///
/// `table` is interpolated into the SQL text and must be a trusted identifier.
fn insert_map<T: Serialize>(
    conn: &mut Connection,
    table: &str,
    map: &BTreeMap<String, T>,
) -> Result<()> {
    let tx = conn.transaction()?;
    let sql = format!("INSERT OR REPLACE INTO {table} (id, def) VALUES (?1, ?2)");
    let mut insert = tx.prepare(&sql)?;
    map.iter().try_for_each(|(id, def)| {
        let json = serde_json::to_string(def)
            .map_err(|e| async_sqlite::rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?;
        insert.execute(params![id, json]).map(drop)
    })?;
    // The prepared statement borrows the transaction and must be released before commit.
    drop(insert);
    tx.commit()
}
/// Creates the `users`, `mounts` and `roles` tables of the new access database if missing.
///
/// Each table stores one JSON definition (`def`) per text key (`id`).
fn init_output_schema(conn: &Connection) -> Result<()> {
    const CREATE_TABLES: [&str; 3] = [
        "CREATE TABLE IF NOT EXISTS users (id TEXT PRIMARY KEY, def TEXT NOT NULL)",
        "CREATE TABLE IF NOT EXISTS mounts (id TEXT PRIMARY KEY, def TEXT NOT NULL)",
        "CREATE TABLE IF NOT EXISTS roles (id TEXT PRIMARY KEY, def TEXT NOT NULL)",
    ];
    for sql in CREATE_TABLES {
        conn.execute(sql, [])?;
    }
    Ok(())
}
/// Top-level structure of the legacy C++ shvbroker CPON config file.
///
/// Every section is optional. A missing `app` section falls back to `AppConfig::default()`.
/// The `ldap` section is parsed but not converted into the new `BrokerConfig`.
#[derive(Debug, Deserialize)]
pub struct LegacyBrokerConfig {
    #[serde(default)]
    pub app: AppConfig,
    pub server: Option<ServerConfig>,
    pub sqlconfig: Option<SqlConfig>,
    pub masters: Option<MastersConfig>,
    pub ldap: Option<LdapConfig>,
    pub azure: Option<AzureConfig>,
}
/// Broker id used when the legacy config does not set `app.brokerId`.
fn default_broker_id() -> String {
    String::from("broker.local")
}
/// `app` section of the legacy config.
#[derive(Debug, Deserialize)]
pub struct AppConfig {
    /// Broker identifier (`brokerId` key); `default_broker_id()` when omitted.
    #[serde(rename = "brokerId", default = "default_broker_id")]
    pub broker_id: String,
}

impl Default for AppConfig {
    /// Used when the whole `app` section is missing.
    fn default() -> Self {
        let broker_id = default_broker_id();
        AppConfig { broker_id }
    }
}
/// `server` section of the legacy config: ports of the listeners to migrate.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ServerConfig {
    /// Plain TCP port (`port`).
    pub port: Option<u16>,
    /// TLS TCP port (`sslPort`).
    pub ssl_port: Option<u16>,
    /// WebSocket ports (`websocket`).
    pub websocket: Option<WebsocketConfig>,
    /// Certificate and key used by the TLS listeners (`ssl`).
    pub ssl: Option<SslConfig>,
}
/// `server.websocket` section of the legacy config.
#[derive(Debug, Deserialize)]
pub struct WebsocketConfig {
    /// Plain WebSocket port.
    pub port: Option<u16>,
    /// Secure WebSocket port. Note the all-lowercase `sslport` key, unlike `sslPort` in the
    /// `server` section.
    #[serde(rename = "sslport")]
    pub ssl_port: Option<u16>,
}
/// Key file used when `server.ssl.key` is not set.
fn default_ssl_key() -> String {
    String::from("server.key")
}
/// Certificate file used when `server.ssl.cert` is not set.
fn default_ssl_cert() -> String {
    String::from("server.crt")
}
/// `server.ssl` section: TLS key and certificate paths, each with a default file name.
#[derive(Debug, Deserialize)]
pub struct SslConfig {
    /// Private key path; `default_ssl_key()` when omitted.
    #[serde(default = "default_ssl_key")]
    pub key: String,
    /// Certificate path; `default_ssl_cert()` when omitted.
    #[serde(default = "default_ssl_cert")]
    pub cert: String,
}
/// `sqlconfig` section: whether the legacy access database is used and where it lives.
#[derive(Clone, Debug, Deserialize)]
pub struct SqlConfig {
    /// Access database switch; `false` when omitted.
    #[serde(default)]
    pub enabled: bool,
    /// Path of the legacy database file. Its parent directory becomes the data directory of
    /// the migrated config.
    pub database: Option<String>,
}
/// `masters` section: connections from this broker to parent (master) brokers.
///
/// Connections are migrated only when `enabled` is true.
#[derive(Debug, Deserialize)]
pub struct MastersConfig {
    /// Section switch; `false` when omitted.
    #[serde(default)]
    pub enabled: bool,
    /// Connections keyed by name. Defaults to empty so that a `masters` section without
    /// `connections` (e.g. a disabled one) does not fail parsing of the whole legacy config.
    #[serde(default)]
    pub connections: BTreeMap<String, MasterConnection>,
}
/// One entry of `masters.connections`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct MasterConnection {
    /// Per-connection switch; `false` when omitted.
    #[serde(default)]
    pub enabled: bool,
    /// `exportedShvPath` key.
    pub exported_shv_path: Option<String>,
    /// Credentials used to log in to the master.
    pub login: Option<LoginConfig>,
    /// Address of the master.
    pub server: Option<MasterServerConfig>,
    /// Device identity announced to the master.
    pub device: Option<DeviceConfig>,
    /// RPC timing settings.
    pub rpc: Option<RpcConfig>,
}
/// `login` sub-section of a master connection.
#[derive(Debug, Deserialize)]
pub struct LoginConfig {
    /// User name; migrated as the user part of the connection URL.
    pub user: String,
    /// Password; migrated into the query of the connection URL.
    pub password: String,
    /// Legacy `type` key. Presumably the password format (e.g. `SHA1`); TODO confirm.
    #[serde(rename = "type")]
    pub password_type: Option<String>,
}
/// `server` sub-section of a master connection.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct MasterServerConfig {
    /// Master address. A value without a `scheme://` part is treated as `tcp://`.
    pub host: String,
    /// `peerVerify` key; parsed but not migrated.
    pub peer_verify: Option<bool>,
}
/// `device` sub-section of a master connection.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DeviceConfig {
    /// Device id passed to the client config.
    pub id: Option<String>,
    /// Mount point passed to the client config (`mountPoint`).
    pub mount_point: Option<String>,
    /// `idFile` key; parsed but not migrated.
    pub id_file: Option<String>,
}
/// `rpc` sub-section of a master connection.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RpcConfig {
    /// `protocolType` key; parsed but not migrated.
    pub protocol_type: Option<String>,
    /// Reconnect interval in seconds (`reconnectInterval`).
    pub reconnect_interval: Option<u64>,
    /// Heartbeat interval in seconds (`heartbeatInterval`).
    pub heartbeat_interval: Option<u64>,
}
/// `ldap` section of the legacy config. It is parsed only; the migration does not convert it.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct LdapConfig {
    pub hostname: String,
    /// Key is `searchBaseDN`; camelCase conversion would give `searchBaseDn`, hence the override.
    #[serde(rename = "searchBaseDN")]
    pub search_base_dn: String,
    pub search_attrs: Option<Vec<String>>,
    /// Pairs of groups; presumably `[ldap_group, shv_group]` as in the azure section (TODO confirm).
    pub group_mapping: Option<Vec<[String; 2]>>,
    pub username: String,
    pub password: String,
}
/// `azure` section of the legacy config (Azure AD login).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AzureConfig {
    /// `[azure_group, shv_group]` pairs (`groupMapping`).
    pub group_mapping: Option<Vec<[String; 2]>>,
    pub client_id: String,
    pub authorize_url: String,
    pub token_url: String,
    /// Whitespace separated list of OAuth scopes.
    pub scopes: Option<String>,
}
impl From<LegacyBrokerConfig> for BrokerConfig {
fn from(cfg: LegacyBrokerConfig) -> Self {
let mut listen = Vec::new();
if let Some(server) = &cfg.server {
if let Some(port) = server.port
&& let Ok(url) = Url::parse(&format!("tcp://127.0.0.1:{port}")) {
listen.push(Listen { url });
}
if let Some(ssl_port) = server.ssl_port
&& let Ok(mut url) = Url::parse(&format!("ssl://0.0.0.0:{ssl_port}"))
&& let Some(ssl) = &server.ssl {
url.query_pairs_mut()
.append_pair("cert", &ssl.cert)
.append_pair("key", &ssl.key);
listen.push(Listen { url });
}
if let Some(ws) = &server.websocket {
if let Some(port) = ws.port
&& let Ok(url) = Url::parse(&format!("ws://127.0.0.1:{port}")) {
listen.push(Listen { url });
}
if let Some(ssl_port) = ws.ssl_port
&& let Ok(mut url) = Url::parse(&format!("wss://0.0.0.0:{ssl_port}"))
&& let Some(ssl) = &server.ssl {
url.query_pairs_mut()
.append_pair("cert", &ssl.cert)
.append_pair("key", &ssl.key);
listen.push(Listen { url });
}
}
}
let mut connections = Vec::new();
if let Some(masters_cfg) = cfg.masters && masters_cfg.enabled {
for (name, mconn) in masters_cfg.connections {
let connection_settings = ConnectionMountSettings {
exported_shv_root: "".into(),
imported_shv_root: "".into(),
mount_point: "".into(),
exported_root_user: "broker".to_string(),
};
let base_host = mconn
.server
.as_ref()
.map(|s| s.host.clone())
.unwrap_or_else(|| "tcp://127.0.0.1".to_string());
let normalized_host = if base_host.contains("://") {
base_host
} else {
format!("tcp://{}", base_host)
};
let mut url = Url::parse(&normalized_host)
.unwrap_or_else(|_| Url::parse("tcp://127.0.0.1").unwrap());
if let Some(login) = &mconn.login {
if url.set_username(&login.user).is_err() {
eprintln!("Cannot set username {user} for URL {url}", user = login.user);
}
url.query_pairs_mut()
.append_pair("password", &login.password);
}
let heartbeat_interval = mconn
.rpc
.as_ref()
.and_then(|r| r.heartbeat_interval)
.map(Duration::from_secs)
.unwrap_or_else(|| Duration::from_secs(60));
let reconnect_interval = mconn
.rpc
.as_ref()
.and_then(|r| r.reconnect_interval)
.map(Duration::from_secs);
let client = ClientConfig {
url,
device_id: mconn.device.as_ref().and_then(|d| d.id.clone()),
mount: mconn.device.as_ref().and_then(|d| d.mount_point.clone()),
heartbeat_interval,
reconnect_interval,
};
connections.push(BrokerConnectionConfig {
name,
enabled: mconn.enabled,
connection_settings,
client,
});
}
}
let azure = cfg.azure.map(|az| {
let group_mapping = az
.group_mapping
.unwrap_or_default()
.into_iter()
.map(|[native_group, shv_group]| (native_group, vec![shv_group]))
.collect::<Vec<_>>();
BrokerAzureConfig {
group_mapping,
client_id: az.client_id,
authorize_url: az.authorize_url,
token_url: az.token_url,
scopes: az
.scopes
.map(|s| s.split_whitespace().map(|x| x.to_string()).collect())
.unwrap_or_default(),
}
});
let data_directory = cfg.sqlconfig.as_ref().and_then(|sql| {
sql.database
.as_ref()
.and_then(|db| Path::new(db).parent())
.map(|p| p.to_string_lossy().to_string())
});
BrokerConfig {
name: Some(cfg.app.broker_id),
listen,
use_access_db: cfg.sqlconfig.as_ref().is_some_and(|s| s.enabled),
shv2_compatibility: false,
time_broadcast: false,
data_directory,
connections,
access: shvbroker::config::AccessConfig::default(),
tunnelling: shvbroker::config::TunnellingConfig::default(),
azure,
google_auth: None,
}
}
}
// Command line arguments. The field doc comments below are rendered by clap as `--help` text.
#[derive(clap::Parser, Debug)]
#[command(
    name = "migrate_legacy_data",
    about = "A tool for converting legacy C++ shvbroker config file and access database to the format used by shvbroker-rs"
)]
struct Args {
    /// Path to the legacy C++ shvbroker CPON config; a relative data directory derived from it is resolved against its directory
    #[arg(long)]
    legacy_config: String,
    /// Path of the YAML config to write [default: shvbroker.yml next to the legacy config]
    #[arg(long)]
    result_config: Option<String>,
}
/// Converts the legacy CPON config into a YAML `BrokerConfig`. When the legacy config enables
/// the SQL access database, also copies its users, mounts and roles into `shvbroker.sqlite`.
///
/// The result config defaults to `shvbroker.yml` next to the legacy config. A relative data
/// directory is resolved against the legacy config's directory. The legacy database file name
/// defaults to `shvbroker.cfg.db`.
fn main() -> shvrpc::Result<()> {
    let args = Args::parse();
    let legacy_config_cpon = std::fs::read_to_string(&args.legacy_config)?;
    let legacy_config: LegacyBrokerConfig =
        shvproto::from_rpcvalue(&RpcValue::from_cpon(legacy_config_cpon)?)?;
    // `into()` consumes the legacy config, but the legacy database file name is needed later.
    let legacy_sql_config = legacy_config.sqlconfig.clone();
    let broker_config: BrokerConfig = legacy_config.into();
    let config_dir = Path::new(&args.legacy_config)
        .parent()
        .unwrap_or_else(|| Path::new("."));
    let result_config = args
        .result_config
        .map_or_else(|| config_dir.join("shvbroker.yml"), |path| path.into());
    println!(
        "Migrating config file from: {from} to: {to}",
        from = args.legacy_config,
        to = result_config.display()
    );
    std::fs::write(&result_config, serde_yaml::to_string(&broker_config)?)?;
    if !broker_config.use_access_db {
        return Ok(());
    }

    let data_dir = Path::new(broker_config.data_directory.as_deref().unwrap_or_default());
    let data_dir = if data_dir.is_relative() {
        config_dir.join(data_dir)
    } else {
        data_dir.to_path_buf()
    };
    println!("data dir: {data_dir:?}");
    let legacy_db_file_name = legacy_sql_config
        .and_then(|sql| sql.database)
        .and_then(|db| Path::new(&db).file_name().map(|name| name.to_os_string()))
        .unwrap_or_else(|| "shvbroker.cfg.db".into());
    let legacy_db_path = data_dir.join(legacy_db_file_name);
    let new_db_path = data_dir.join("shvbroker.sqlite");
    println!(
        "Migrating the access database from: {from} to: {to}",
        from = legacy_db_path.display(),
        to = new_db_path.display()
    );
    smol::block_on(async {
        let input_conn = ClientBuilder::new()
            .path(legacy_db_path)
            .flags(OpenFlags::SQLITE_OPEN_READ_ONLY)
            .open()
            .await?;
        let (users, mounts, roles) = input_conn
            .conn(|conn| Ok((load_users(conn)?, load_mounts(conn)?, load_roles(conn)?)))
            .await?;
        input_conn.close().await?;
        // Open the output only after the legacy data has been read, so a failed read does not
        // create an empty output database.
        let output_conn = ClientBuilder::new().path(new_db_path).open().await?;
        output_conn
            .conn_mut(move |conn| {
                init_output_schema(conn)?;
                insert_map(conn, "users", &users)?;
                insert_map(conn, "mounts", &mounts)?;
                insert_map(conn, "roles", &roles)
            })
            .await?;
        // Close explicitly so a close failure is reported instead of being lost at process exit.
        output_conn.close().await?;
        shvrpc::Result::Ok(())
    })
}