pub mod certs;
pub mod config;
pub mod device;
pub mod git;
pub mod ota;
use clap::Parser;
use pyrinas_shared::{ota::OTAPackageVersion, OtaLink};
use serde::{Deserialize, Serialize};
use std::{net::TcpStream, num};
use thiserror::Error;
use tungstenite::{http, http::Request, protocol::WebSocket, stream::MaybeTlsStream};
#[derive(Debug, Error)]
pub enum Error {
#[error("{source}")]
Error {
#[from]
source: config::Error,
},
#[error("http error: {source}")]
HttpError {
#[from]
source: http::Error,
},
#[error("websocket handshake error {source}")]
WebsocketError {
#[from]
source: tungstenite::Error,
},
#[error("semver error: {source}")]
SemVerError {
#[from]
source: semver::Error,
},
#[error("parse error: {source}")]
ParseError {
#[from]
source: num::ParseIntError,
},
#[error("err: {0}")]
CustomError(String),
#[error("ota error: {source}")]
OtaError {
#[from]
source: ota::Error,
},
#[error("{source}")]
CertsError {
#[from]
source: certs::Error,
},
}
// Command wrapper for OTA operations; holds the selected `OtaSubCommand`.
// Plain `//` comments are used on clap-derived types on purpose: clap turns `///` doc
// comments into `--help` text.
#[derive(Parser, Debug)]
#[clap(version)]
pub struct OtaCmd {
    // The chosen OTA action, parsed as a clap subcommand.
    #[clap(subcommand)]
    pub subcmd: OtaSubCommand,
}
// Command wrapper for certificate operations; holds the selected `CertSubcommand`.
// (`//` rather than `///` so clap's generated help text is unaffected.)
#[derive(Parser, Debug)]
#[clap(version)]
pub struct CertCmd {
    // The chosen certificate action, parsed as a clap subcommand.
    #[clap(subcommand)]
    pub subcmd: CertSubcommand,
}
// Certificate actions. clap derives the subcommand names from the variant names in
// kebab-case: `ca`, `server`, `device`.
// NOTE(review): presumably these generate the CA, server and per-device certificates
// respectively — confirm against the `certs` module.
// NOTE(review): named `CertSubcommand`, while the sibling enums use `...SubCommand`.
#[derive(Parser, Debug)]
#[clap(version)]
pub enum CertSubcommand {
    // No arguments.
    Ca,
    // No arguments.
    Server,
    // Takes the arguments described by `CertDevice`.
    Device(CertDevice),
}
// Arguments for the `CertSubcommand::Device` subcommand.
// (`//` rather than `///` so clap's generated help text is unaffected.)
// NOTE(review): unlike `OtaAdd`, the fields are private. They are readable only from
// this module and its submodules (e.g. `certs`).
#[derive(Parser, Debug)]
#[clap(version)]
pub struct CertDevice {
    // Optional positional argument (no `long`/`short`). Presumably the ID of the device
    // the certificate is for — TODO confirm in `certs`.
    id: Option<String>,
    // `--provision` / `-p` boolean flag.
    // NOTE(review): presumably also loads the generated credentials onto the device over
    // `port` — confirm.
    #[clap(long, short)]
    provision: bool,
    // `--port`; defaults to `certs::DEFAULT_MAC_PORT`.
    // NOTE(review): likely a serial port name/path — confirm.
    #[clap(long, default_value = certs::DEFAULT_MAC_PORT )]
    port: String,
    // `--tag` / `-t`; optional numeric tag.
    // NOTE(review): possibly the same kind of tag as `CertEntry::tag` — confirm.
    #[clap(long, short)]
    tag: Option<u32>,
}
// OTA actions. clap derives the subcommand names from the variant names in kebab-case:
// `add`, `link`, `unlink`, `remove`, `list-groups`, `list-images`.
// (`//` rather than `///` so clap's generated help text is unaffected.)
#[derive(Parser, Debug)]
#[clap(version)]
pub enum OtaSubCommand {
    // Arguments described by `OtaAdd`.
    Add(OtaAdd),
    // `Link` and `Unlink` share the same argument type, `pyrinas_shared::OtaLink`.
    Link(OtaLink),
    Unlink(OtaLink),
    // Arguments described by `OtaRemove`.
    Remove(OtaRemove),
    // No arguments.
    ListGroups,
    // No arguments.
    ListImages,
}
// Arguments for the OTA `add` subcommand.
// (`//` rather than `///` so clap's generated help text is unaffected.)
#[derive(Parser, Debug)]
#[clap(version)]
pub struct OtaAdd {
    // `--force` / `-f` boolean flag.
    #[clap(long, short)]
    pub force: bool,
    // `--device-id` / `-d`; optional. Presumably targets the image at one device — TODO confirm.
    #[clap(long, short)]
    pub device_id: Option<String>,
    // `--ota-version`, parsed as `u8`; defaults to `pyrinas_shared::DEFAULT_OTA_VERSION`.
    #[clap(long, default_value = pyrinas_shared::DEFAULT_OTA_VERSION)]
    pub ota_version: u8,
}
// Arguments for the OTA `remove` subcommand.
// (`//` rather than `///` so clap's generated help text is unaffected.)
#[derive(Parser, Debug)]
#[clap(version)]
pub struct OtaRemove {
    // Required positional argument: ID of the image to remove.
    pub image_id: String,
}
/// Certificate settings, stored as the `cert` field of [`Config`].
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct CertConfig {
    /// Domain name. NOTE(review): presumably used as the server certificate's subject/SAN —
    /// confirm in `certs`.
    pub domain: String,
    /// Organization name. NOTE(review): presumably the certificate subject `O` field.
    pub organization: String,
    /// Country. NOTE(review): presumably the certificate subject `C` field.
    pub country: String,
    /// Password string. NOTE(review): presumably protects exported `.pfx` (PKCS#12) files.
    /// It is serialized as-is together with the rest of the configuration.
    pub pfx_pass: String,
}
/// A set of certificate/key strings identified by `tag`; [`Config`] holds an optional
/// list of these in `alts`.
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct CertEntry {
    /// Numeric tag identifying this entry.
    pub tag: u32,
    /// CA certificate, if present. NOTE(review): presumably PEM text — confirm.
    pub ca_cert: Option<String>,
    /// Private key, if present. NOTE(review): presumably PEM text — confirm.
    pub private_key: Option<String>,
    /// Public key, if present. NOTE(review): presumably PEM text — confirm.
    pub pub_key: Option<String>,
}
/// Client configuration, (de)serialized with serde.
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct Config {
    /// Server address, given without a scheme. `get_socket` connects to
    /// `ws://<url>/socket` or `wss://<url>/socket`.
    pub url: String,
    /// `true` selects `wss`, `false` plain `ws` in `get_socket`, which also emits a
    /// warning when this is `false`.
    pub secure: bool,
    /// Sent as the `ApiKey` request header when `get_socket` opens the connection.
    pub authkey: String,
    /// Certificate settings (see [`CertConfig`]).
    pub cert: CertConfig,
    /// Optional list of extra [`CertEntry`] values.
    /// NOTE(review): `alts` presumably means alternate certificates — confirm.
    pub alts: Option<Vec<CertEntry>>,
}
// Command wrapper for configuration operations; holds the selected `ConfigSubCommand`.
// (`//` rather than `///` so clap's generated help text is unaffected.)
// NOTE(review): also derives Serialize/Deserialize, which clap does not need — confirm
// whether anything serializes this type.
#[derive(Parser, Debug, Serialize, Deserialize)]
#[clap(version)]
pub struct ConfigCmd {
    // The chosen config action, parsed as a clap subcommand.
    #[clap(subcommand)]
    pub subcmd: ConfigSubCommand,
}
// Configuration actions, exposed by clap as the subcommands `show` and `init`.
// (`//` rather than `///` so clap's generated help text is unaffected.)
#[derive(Parser, Debug, Serialize, Deserialize)]
#[clap(version)]
pub enum ConfigSubCommand {
    // Arguments described by `Show` (currently none).
    Show(Show),
    // No arguments.
    Init,
}
// Arguments for the config `show` subcommand; the struct is empty, so it takes none.
// (`//` rather than `///` so clap's generated help text is unaffected.)
#[derive(Parser, Debug, Serialize, Deserialize)]
#[clap(version)]
pub struct Show {}
/// Serializable description of an OTA image: its version, its file and a force flag.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct OTAManifest {
    /// Package version of the image.
    pub version: OTAPackageVersion,
    /// NOTE(review): presumably the path or name of the image file — confirm.
    pub file: String,
    /// NOTE(review): presumably mirrors `OtaAdd::force` (force the update) — confirm.
    pub force: bool,
}
/// Opens a WebSocket connection to the server described by `config`.
///
/// The endpoint is `ws://<url>/socket`, or `wss://<url>/socket` when `config.secure` is
/// set. `config.authkey` is sent in the `ApiKey` request header. When `config.secure` is
/// `false`, a warning is written to stderr so it never mixes with data a command prints
/// on stdout.
///
/// # Errors
///
/// - [`Error::HttpError`] if the request cannot be built, e.g. `authkey` is not a valid
///   header value or the resulting URI is malformed.
/// - [`Error::WebsocketError`] if connecting or the WebSocket handshake fails.
pub fn get_socket(config: &Config) -> Result<WebSocket<MaybeTlsStream<TcpStream>>, Error> {
    if !config.secure {
        // stderr, not stdout: keeps the warning out of piped / machine-parsed output.
        eprintln!("WARNING! Not using secure web socket connection!");
    }

    let scheme = if config.secure { "wss" } else { "ws" };
    let full_uri = format!("{}://{}/socket", scheme, config.url);

    // Borrow the key; `HeaderValue: TryFrom<&str>`, so no owned copy is needed.
    let req = Request::builder()
        .uri(full_uri)
        .header("ApiKey", config.authkey.as_str())
        .body(())?;

    let (socket, _response) = tungstenite::connect(req)?;
    Ok(socket)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Initializes `env_logger` for the test binary.
    ///
    /// Safe to call from every test. `try_init` returns an error once a logger is already
    /// installed, and that error is deliberately ignored. `is_test(true)` routes output
    /// through the test harness so it is captured per test.
    fn setup() {
        let _ = env_logger::builder().is_test(true).try_init();
    }

    /// The version that every `0.2.1-19-g09db6ef[-dirty]` describe string should parse to.
    fn expected_version() -> OTAPackageVersion {
        OTAPackageVersion {
            major: 0,
            minor: 2,
            patch: 1,
            commit: 19,
            hash: *b"g09db6ef",
        }
    }

    #[test]
    fn get_ota_package_version_success_with_dirty() {
        setup();
        let (package_ver, dirty) = git::get_ota_package_version("0.2.1-19-g09db6ef-dirty")
            .expect("describe string with a -dirty suffix should parse");
        assert!(dirty);
        assert_eq!(package_ver, expected_version());
    }

    #[test]
    fn get_ota_package_version_success_clean() {
        setup();
        let (package_ver, dirty) = git::get_ota_package_version("0.2.1-19-g09db6ef")
            .expect("clean describe string should parse");
        assert!(!dirty);
        assert_eq!(package_ver, expected_version());
    }

    /// A describe string missing the commit-count component must be rejected.
    #[test]
    fn get_ota_package_version_failure_dirty() {
        setup();
        let res = git::get_ota_package_version("0.2.1-g09db6ef-dirty");
        assert!(res.is_err());
    }

    /// Opt-in test: `git::get_git_describe` presumably inspects the local git checkout, so
    /// the result is environment-dependent. Run with `cargo test -- --ignored`.
    #[test]
    #[ignore = "environment-dependent; run manually with --ignored"]
    fn get_git_describe_success() {
        setup();
        let describe = git::get_git_describe().expect("get_git_describe should succeed");
        log::info!("res: {}", describe);
    }
}