use core::fmt::Write as _;
use heapless::String;
use serde::Deserialize;
use crate::error::{Error, ErrorCode};
use crate::utils::storage::WriteBuf;
use crate::utils::sync::IfMutex;
use super::{OtaImageMeta, OtaImages, OtaImagesRegistry, OtaQueryOutcome, MAX_FILE_DESIGNATOR};
/// Base URL of the mainnet DCL REST endpoint (used by [`DclImages::mainnet`]).
pub const DCL_MAINNET: &str = "https://on.dcl.csa-iot.org";
/// Base URL of the test-net DCL REST endpoint (used by [`DclImages::testnet`]).
pub const DCL_TESTNET: &str = "https://on.test-net.dcl.csa-iot.org";
/// Capacity of the software-version list parsed from a DCL version-list response.
const MAX_VERSIONS: usize = 16;
/// Capacity, in bytes, of every URL this module builds or caches; also the size of the
/// URL area carved off the front of the scratch buffer.
const MAX_URL: usize = 256;
/// HTTP client abstraction used both to query the DCL REST API and to download OTA images.
pub trait OtaHttp {
    /// Issues a GET request for `url` and writes (part of) the response body into `buf`.
    ///
    /// `range` is `None` for DCL metadata requests, where the whole body is expected to fit
    /// into `buf`. For image downloads it is `Some((offset, len))` with `len == buf.len()`,
    /// asking for up to `len` bytes starting at byte `offset` of the resource.
    ///
    /// Returns the number of bytes written into `buf`. NOTE(review): the in-file test mock
    /// returns `Ok(0)` once `offset` is past the end of the resource, and the test's download
    /// loop stops on 0 — confirm that end-of-file contract for real implementations.
    async fn get(
        &self,
        url: &str,
        range: Option<(u64, usize)>,
        buf: &mut [u8],
    ) -> Result<usize, Error>;
}
/// Forwards to the referenced client, so one [`OtaHttp`] implementation can be shared by
/// reference between several users.
impl<T> OtaHttp for &T
where
    T: OtaHttp,
{
    async fn get(
        &self,
        url: &str,
        range: Option<(u64, usize)>,
        buf: &mut [u8],
    ) -> Result<usize, Error> {
        // `self` is `&&T`; dereference once and call the inner client's impl explicitly.
        <T as OtaHttp>::get(*self, url, range, buf).await
    }
}
/// OTA image registry backed by the CSA DCL REST API.
///
/// Image metadata is fetched with `http` from endpoints under `base_url`. All requests share
/// one caller-provided scratch buffer, so operations are serialized through `locked`.
pub struct DclImages<'a, H> {
    /// Client used for DCL metadata requests and for image downloads.
    http: H,
    /// DCL REST base URL, e.g. [`DCL_MAINNET`] or [`DCL_TESTNET`].
    base_url: &'a str,
    /// Scratch memory and the single-entry image cache.
    locked: IfMutex<Locked<'a>>,
}
/// Mutable state of a [`DclImages`], accessed only while holding its lock.
struct Locked<'a> {
    /// Working memory; `carve` splits it into a `MAX_URL`-byte URL area and a JSON
    /// response area.
    scratch: &'a mut [u8],
    /// The most recently resolved image.
    cache: Cached,
}
/// Single-entry cache mapping one file designator to its download URL and size.
struct Cached {
    /// Designator of the cached image (the cache key); empty until something is cached.
    designator: String<MAX_FILE_DESIGNATOR>,
    /// Download URL (`otaUrl`) of the cached image.
    ota_url: String<MAX_URL>,
    /// Size (`otaFileSize`) of the cached image.
    size: u64,
}
impl<'a, H> DclImages<'a, H> {
    /// Creates a registry that resolves images against the DCL REST API at `base_url`.
    ///
    /// `scratch` holds the request URL (its first `MAX_URL` bytes) and the JSON response
    /// (the rest), so it must be large enough for the biggest DCL response expected.
    ///
    /// # Panics
    ///
    /// If `scratch.len() <= MAX_URL`, i.e. if no room would remain for responses.
    pub const fn new(http: H, base_url: &'a str, scratch: &'a mut [u8]) -> Self {
        if scratch.len() <= MAX_URL {
            panic!("DclImages scratch buffer must be larger than MAX_URL");
        }
        let cache = Cached {
            designator: String::new(),
            ota_url: String::new(),
            size: 0,
        };
        let locked = IfMutex::new(Locked { scratch, cache });
        Self {
            http,
            base_url,
            locked,
        }
    }

    /// Same as [`DclImages::new`] with [`DCL_MAINNET`] as the base URL.
    ///
    /// # Panics
    ///
    /// If `scratch.len() <= MAX_URL`.
    pub const fn mainnet(http: H, scratch: &'a mut [u8]) -> Self {
        Self::new(http, DCL_MAINNET, scratch)
    }

    /// Same as [`DclImages::new`] with [`DCL_TESTNET`] as the base URL.
    ///
    /// # Panics
    ///
    /// If `scratch.len() <= MAX_URL`.
    pub const fn testnet(http: H, scratch: &'a mut [u8]) -> Self {
        Self::new(http, DCL_TESTNET, scratch)
    }
}
impl<H: OtaHttp> DclImages<'_, H> {
    /// Ensures `locked.cache` holds the download URL and size for file designator `fd`.
    ///
    /// On a cache hit this returns immediately. On a miss, `fd` is parsed as
    /// `VVVV-PPPP-version` and the DCL record for that version is fetched. The record is
    /// cached under `fd` only if it is marked valid and describes the requested version,
    /// so a revoked image cannot be downloaded by naming its designator directly.
    ///
    /// Returns `None` in any of these cases:
    /// - `fd` is malformed;
    /// - the request or JSON parsing fails;
    /// - the version is not valid;
    /// - the record does not fit into the cache.
    async fn resolve_into(&self, locked: &mut Locked<'_>, fd: &[u8]) -> Option<()> {
        // An empty designator must never "hit" the initially empty cache.
        if !fd.is_empty() && locked.cache.designator.as_bytes() == fd {
            return Some(());
        }
        let (vid, pid, version) = parse_designator(fd)?;
        let (url_buf, json_buf) = carve(locked.scratch)?;
        let url = build_detail_url(url_buf, self.base_url, vid, pid, version)?;
        let n = self.http.get(url, None, json_buf).await.ok()?;
        let (resp, _) = serde_json_core::from_slice::<DetailResponse>(&json_buf[..n]).ok()?;
        let mv = resp.model_version;
        if !mv.valid || mv.software_version != version {
            return None;
        }
        let fd = core::str::from_utf8(fd).ok()?;
        store(&mut locked.cache, fd, mv.ota_url, mv.ota_file_size)
    }

    /// Finds the newest DCL software version that a device at `current_version` may install.
    ///
    /// Versions newer than `current_version` are tried newest-first. The first one whose DCL
    /// record is valid and whose applicable range contains `current_version` is selected.
    /// Falling back to older candidates lets a device that cannot jump straight to the latest
    /// release (because of its `minApplicableSoftwareVersion`) still be offered an
    /// intermediate step. Candidates whose record fails to parse (e.g. a version without an
    /// `otaUrl`) are skipped.
    ///
    /// On success:
    /// - the image's designator is written into `designator_buf`;
    /// - the image is cached for later `size`/`read` calls;
    /// - its metadata is returned.
    ///
    /// Returns `None` if no candidate qualifies or a DCL request fails.
    async fn resolve_query<'b>(
        &self,
        vendor_id: u16,
        product_id: u16,
        current_version: u32,
        designator_buf: &'b mut [u8],
    ) -> Option<OtaImageMeta<'b>> {
        let mut guard = self.locked.lock().await;
        let locked = &mut *guard;
        let (url_buf, json_buf) = carve(locked.scratch)?;
        // `VersionsResponse` owns its data, so the list remains usable while `json_buf` is
        // reused for the per-version detail requests below.
        let mut versions = {
            let url = build_list_url(&mut *url_buf, self.base_url, vendor_id, product_id)?;
            let n = self.http.get(url, None, &mut *json_buf).await.ok()?;
            let (resp, _) =
                serde_json_core::from_slice::<VersionsResponse>(&json_buf[..n]).ok()?;
            resp.model_versions.software_versions
        };
        versions.sort_unstable();
        let mut pool: &[u32] = &versions;
        while let Some(candidate) = select_version(pool, current_version) {
            // Later rounds only consider strictly older versions; this also skips duplicates.
            pool = &pool[..pool.partition_point(|v| *v < candidate)];

            let url = build_detail_url(
                &mut *url_buf,
                self.base_url,
                vendor_id,
                product_id,
                candidate,
            )?;
            let n = self.http.get(url, None, &mut *json_buf).await.ok()?;
            let Ok((resp, _)) = serde_json_core::from_slice::<DetailResponse>(&json_buf[..n])
            else {
                continue;
            };
            let mv = resp.model_version;
            if mv.software_version != candidate
                || !applicable(
                    mv.valid,
                    mv.min_applicable,
                    mv.max_applicable,
                    current_version,
                )
            {
                continue;
            }

            let designator =
                write_designator(designator_buf, vendor_id, product_id, mv.software_version)?;
            store(&mut locked.cache, designator, mv.ota_url, mv.ota_file_size)?;
            return Some(OtaImageMeta {
                version: mv.software_version,
                file_designator: designator,
                update_token: designator.as_bytes(),
                size: Some(mv.ota_file_size),
                user_consent_needed: false,
            });
        }
        None
    }
}
/// Replaces the single-entry image cache with `designator` → (`ota_url`, `size`).
///
/// The entry is invalidated first, and the designator (the cache key) is written last. If
/// either string does not fit, the cache is left with an empty designator, i.e. no entry.
/// It can never pair the new designator with a stale or missing URL and size.
///
/// Returns `None` if `designator` or `ota_url` exceeds its capacity.
fn store(cache: &mut Cached, designator: &str, ota_url: &str, size: u64) -> Option<()> {
    cache.designator.clear();
    cache.ota_url.clear();
    cache.ota_url.push_str(ota_url).ok()?;
    cache.size = size;
    if cache.designator.push_str(designator).is_err() {
        // Defensive: never leave a partially written key behind.
        cache.designator.clear();
        return None;
    }
    Some(())
}
impl<H: OtaHttp> OtaImagesRegistry for DclImages<'_, H> {
    /// Answers an OTA image query from the DCL.
    ///
    /// Yields `Available` with the resolved image metadata. If no image could be resolved
    /// for any reason, including network or parse failures, it yields `NotAvailable`. The
    /// requestor's ability to obtain user consent is not consulted.
    async fn query<'b>(
        &self,
        vendor_id: u16,
        product_id: u16,
        current_version: u32,
        _requestor_can_consent: bool,
        designator_buf: &'b mut [u8],
    ) -> OtaQueryOutcome<'b> {
        let image = self
            .resolve_query(vendor_id, product_id, current_version, designator_buf)
            .await;
        image.map_or(OtaQueryOutcome::NotAvailable, OtaQueryOutcome::Available)
    }
}
impl<H: OtaHttp> OtaImages for DclImages<'_, H> {
    /// Returns the size of the image named by `file_designator`, resolving it through the
    /// DCL on a cache miss, or `None` if it cannot be resolved.
    async fn size(&self, file_designator: &[u8]) -> Option<u64> {
        let mut guard = self.locked.lock().await;
        let state = &mut *guard;
        self.resolve_into(state, file_designator).await?;
        let cache = &state.cache;
        if cache.designator.as_bytes() == file_designator {
            Some(cache.size)
        } else {
            None
        }
    }

    /// Reads up to `buf.len()` bytes of the image named by `file_designator`, starting at
    /// byte `offset`, and returns the number of bytes read.
    ///
    /// # Errors
    ///
    /// - `InvalidData` if the designator cannot be resolved.
    /// - `NoSpace` if the cached URL does not fit the local copy.
    /// - Any error reported by the HTTP client.
    async fn read(
        &self,
        file_designator: &[u8],
        offset: u64,
        buf: &mut [u8],
    ) -> Result<usize, Error> {
        // Copy the URL out under the lock, then release it before the download itself.
        let url = {
            let mut guard = self.locked.lock().await;
            let state = &mut *guard;
            let resolved = self.resolve_into(state, file_designator).await.is_some();
            if !resolved || state.cache.designator.as_bytes() != file_designator {
                return Err(ErrorCode::InvalidData.into());
            }
            let mut url = String::<MAX_URL>::new();
            url.push_str(&state.cache.ota_url)
                .map_err(|_| ErrorCode::NoSpace)?;
            url
        };
        self.http.get(&url, Some((offset, buf.len())), buf).await
    }
}
/// Splits `scratch` into a `MAX_URL`-byte URL area and the remaining response area.
///
/// Returns `None` unless `scratch.len() > MAX_URL`, i.e. unless the response area would be
/// non-empty.
fn carve(scratch: &mut [u8]) -> Option<(&mut [u8], &mut [u8])> {
    if scratch.len() <= MAX_URL {
        return None;
    }
    Some(scratch.split_at_mut(MAX_URL))
}
/// Formats the DCL version-list URL, `{base}/dcl/model/versions/{vid}/{pid}`, into `url_buf`.
///
/// Returns the URL as a string slice borrowing `url_buf`. Trailing `/` characters on `base`
/// are dropped, so a base given as `https://host/` does not produce a `//dcl/...` path.
/// Returns `None` if formatting into `url_buf` fails (presumably because it is too small).
fn build_list_url<'u>(url_buf: &'u mut [u8], base: &str, vid: u16, pid: u16) -> Option<&'u str> {
    let base = base.trim_end_matches('/');
    let len = {
        let mut wb = WriteBuf::new(&mut *url_buf);
        write!(wb, "{base}/dcl/model/versions/{vid}/{pid}").ok()?;
        wb.as_slice().len()
    };
    core::str::from_utf8(&url_buf[..len]).ok()
}
/// Formats the DCL model-version detail URL,
/// `{base}/dcl/model/versions/{vid}/{pid}/{version}`, into `url_buf`.
///
/// Returns the URL as a string slice borrowing `url_buf`. Trailing `/` characters on `base`
/// are dropped, so a base given as `https://host/` does not produce a `//dcl/...` path.
/// Returns `None` if formatting into `url_buf` fails (presumably because it is too small).
fn build_detail_url<'u>(
    url_buf: &'u mut [u8],
    base: &str,
    vid: u16,
    pid: u16,
    version: u32,
) -> Option<&'u str> {
    let base = base.trim_end_matches('/');
    let len = {
        let mut wb = WriteBuf::new(&mut *url_buf);
        write!(wb, "{base}/dcl/model/versions/{vid}/{pid}/{version}").ok()?;
        wb.as_slice().len()
    };
    core::str::from_utf8(&url_buf[..len]).ok()
}
/// Body of the DCL version-list response (the URL built by `build_list_url`).
///
/// Only the fields used here are declared. The test fixture also carries `vid` and `pid`,
/// which are expected to be skipped during parsing.
#[derive(Deserialize)]
struct VersionsResponse {
    #[serde(rename = "modelVersions")]
    model_versions: VersionList,
}
/// The `modelVersions` object: software versions published for one VID/PID.
#[derive(Deserialize)]
struct VersionList {
    /// Published software versions, in whatever order the DCL returns them.
    /// NOTE(review): capacity is `MAX_VERSIONS`; a longer list presumably fails to
    /// deserialize (so no update is ever offered) — confirm and size accordingly.
    #[serde(rename = "softwareVersions")]
    software_versions: heapless::Vec<u32, MAX_VERSIONS>,
}
/// Body of the DCL model-version detail response (the URL built by `build_detail_url`).
/// String fields borrow from the response buffer.
#[derive(Deserialize)]
struct DetailResponse<'a> {
    #[serde(borrow, rename = "modelVersion")]
    model_version: ModelVersion<'a>,
}
/// The subset of a DCL model-version record needed to select and download an image.
///
/// NOTE(review): the DCL REST gateway may encode 64-bit integers such as `otaFileSize` as
/// JSON strings; the fixtures in this file use plain numbers — confirm against a live
/// response.
#[derive(Deserialize)]
struct ModelVersion<'a> {
    /// The software version this record describes.
    #[serde(rename = "softwareVersion")]
    software_version: u32,
    /// Must be true for the image to be offered (checked by `applicable`).
    #[serde(rename = "softwareVersionValid")]
    valid: bool,
    /// Download URL of the image. This field is required, so records without it fail to
    /// parse. NOTE(review): borrowed straight from the buffer — confirm how
    /// `serde_json_core` handles JSON escape sequences in borrowed strings.
    #[serde(borrow, rename = "otaUrl")]
    ota_url: &'a str,
    /// Size of the image, reported through `OtaImages::size` and the query metadata.
    #[serde(rename = "otaFileSize")]
    ota_file_size: u64,
    /// Lowest current version this image applies to (inclusive, see `applicable`).
    #[serde(rename = "minApplicableSoftwareVersion")]
    min_applicable: u32,
    /// Highest current version this image applies to (inclusive, see `applicable`).
    #[serde(rename = "maxApplicableSoftwareVersion")]
    max_applicable: u32,
}
/// Returns the highest version in `versions` that is strictly newer than `current`, or
/// `None` if there is none. The order of `versions` does not matter.
fn select_version(versions: &[u32], current: u32) -> Option<u32> {
    versions.iter().filter(|&&v| v > current).max().copied()
}
/// Whether a DCL software version may be offered to a device running `current`.
///
/// The version must be marked valid, and `current` must lie in the inclusive range
/// `min..=max`.
fn applicable(valid: bool, min: u32, max: u32, current: u32) -> bool {
    if !valid {
        return false;
    }
    (min..=max).contains(&current)
}
/// Writes the file designator `VVVV-PPPP-version` into `buf`.
///
/// The vendor and product IDs are written as four upper-case hex digits and the version in
/// decimal. Returns the designator as a string slice borrowing `buf`, or `None` if
/// formatting into `buf` fails.
fn write_designator(buf: &mut [u8], vid: u16, pid: u16, version: u32) -> Option<&str> {
    let written = {
        let mut out = WriteBuf::new(&mut *buf);
        match write!(out, "{vid:04X}-{pid:04X}-{version}") {
            Ok(()) => out.as_slice().len(),
            Err(_) => return None,
        }
    };
    let text = &buf[..written];
    core::str::from_utf8(text).ok()
}
/// Parses a designator of the form `VVVV-PPPP-version` (hex vendor ID, hex product ID,
/// decimal version) back into its parts.
///
/// Returns `None` for non-UTF-8 input, a dash count other than two, or any part that does
/// not parse.
fn parse_designator(fd: &[u8]) -> Option<(u16, u16, u32)> {
    let text = core::str::from_utf8(fd).ok()?;
    let (vid_hex, rest) = text.split_once('-')?;
    let (pid_hex, version_dec) = rest.split_once('-')?;
    if version_dec.contains('-') {
        return None;
    }
    let vid = u16::from_str_radix(vid_hex, 16).ok()?;
    let pid = u16::from_str_radix(pid_hex, 16).ok()?;
    let version: u32 = version_dec.parse().ok()?;
    Some((vid, pid, version))
}
#[cfg(test)]
mod tests {
    //! Unit tests for the pure helpers and the DCL response types. Also runs an end-to-end
    //! query → size → chunked read against an in-memory HTTP mock.
    use core::future::Future;
    use core::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
    use super::*;
    /// Version-list fixture: VID 0x1234 (4660), PID 0x5678 (22136), versions 1, 2 and 10.
    const VERSIONS_JSON: &[u8] =
        br#"{"modelVersions":{"vid":4660,"pid":22136,"softwareVersions":[1,2,10]}}"#;
    /// Detail fixture for version 10: valid, applicable to current versions 1..=9, with a
    /// 2500-byte image at `https://fw.example.com/fw/10.ota`. It also carries fields that
    /// `ModelVersion` does not declare. The continuation lines are part of the literal.
    const DETAIL_JSON: &[u8] = br#"{"modelVersion":{"vid":4660,"pid":22136,"softwareVersion":10,
"softwareVersionString":"1.0.10","cdVersionNumber":1,"firmwareInformation":"",
"softwareVersionValid":true,"otaUrl":"https://fw.example.com/fw/10.ota",
"otaFileSize":2500,"otaChecksum":"abcd","otaChecksumType":1,
"minApplicableSoftwareVersion":1,"maxApplicableSoftwareVersion":9,
"releaseNotesUrl":"","creator":"cosmos1xyz"}}"#;
    #[test]
    fn select_version_picks_highest_newer() {
        assert_eq!(select_version(&[1, 2, 10], 1), Some(10));
        assert_eq!(select_version(&[1, 2, 10], 9), Some(10));
        assert_eq!(select_version(&[1, 2, 10], 10), None);
        assert_eq!(select_version(&[], 0), None);
    }
    #[test]
    fn applicable_checks_validity_and_range() {
        assert!(applicable(true, 1, 9, 1));
        assert!(applicable(true, 1, 9, 9));
        assert!(!applicable(false, 1, 9, 5));
        assert!(!applicable(true, 5, 9, 1));
        assert!(!applicable(true, 1, 9, 10));
    }
    #[test]
    fn designator_round_trips() {
        let mut buf = [0u8; 64];
        let fd = write_designator(&mut buf, 0x1234, 0x5678, 10).unwrap();
        assert_eq!(fd, "1234-5678-10");
        assert_eq!(parse_designator(fd.as_bytes()), Some((0x1234, 0x5678, 10)));
        assert_eq!(parse_designator(b"nonsense"), None);
        assert_eq!(parse_designator(b"1234-5678-10-extra"), None);
    }
    #[test]
    fn parses_dcl_responses() {
        let (versions, _) = serde_json_core::from_slice::<VersionsResponse>(VERSIONS_JSON).unwrap();
        assert_eq!(versions.model_versions.software_versions, &[1, 2, 10]);
        let (detail, _) = serde_json_core::from_slice::<DetailResponse>(DETAIL_JSON).unwrap();
        let mv = detail.model_version;
        assert_eq!(mv.software_version, 10);
        assert!(mv.valid);
        assert_eq!(mv.ota_url, "https://fw.example.com/fw/10.ota");
        assert_eq!(mv.ota_file_size, 2500);
        assert_eq!((mv.min_applicable, mv.max_applicable), (1, 9));
    }
    /// In-memory stand-in for the HTTP client.
    ///
    /// URLs containing `fw.example.com` serve `firmware`. Other URLs are routed by the number
    /// of `/`-separated segments after `/versions/`: three or more (vid/pid/version) gets
    /// `DETAIL_JSON`, anything else gets `VERSIONS_JSON`. Ranges are honoured, and a start
    /// offset past the end yields `Ok(0)`.
    struct MockHttp {
        firmware: heapless::Vec<u8, 4096>,
    }
    impl OtaHttp for MockHttp {
        async fn get(
            &self,
            url: &str,
            range: Option<(u64, usize)>,
            buf: &mut [u8],
        ) -> Result<usize, Error> {
            let body: &[u8] = if url.contains("fw.example.com") {
                &self.firmware
            } else {
                let tail = url.split("/versions/").nth(1).unwrap_or("");
                if tail.split('/').count() >= 3 {
                    DETAIL_JSON
                } else {
                    VERSIONS_JSON
                }
            };
            // No range means "whole body, as much as fits in `buf`".
            let (start, len) = range.unwrap_or((0, buf.len()));
            let start = start as usize;
            if start >= body.len() {
                return Ok(0);
            }
            let end = (start + len).min(body.len());
            let n = end - start;
            buf[..n].copy_from_slice(&body[start..end]);
            Ok(n)
        }
    }
    /// Minimal executor: polls `fut` in a busy loop with a no-op waker until it completes.
    /// Only suitable for futures that never wait on external wakeups, like the mock-driven
    /// ones in this module.
    fn block_on<F: Future>(fut: F) -> F::Output {
        fn clone(_: *const ()) -> RawWaker {
            RawWaker::new(core::ptr::null(), &VTABLE)
        }
        fn noop(_: *const ()) {}
        static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, noop, noop, noop);
        // SAFETY: every vtable function ignores the data pointer, so a null data pointer is
        // valid for all of them, and `clone` returns a waker with the same vtable.
        let waker = unsafe { Waker::from_raw(RawWaker::new(core::ptr::null(), &VTABLE)) };
        let mut cx = Context::from_waker(&waker);
        let mut fut = core::pin::pin!(fut);
        loop {
            if let Poll::Ready(v) = fut.as_mut().poll(&mut cx) {
                return v;
            }
        }
    }
    /// Query finds version 10, `size` and a chunked `read` loop return exactly the firmware,
    /// and a device already on version 10 is told nothing is available.
    #[test]
    fn end_to_end_query_then_download() {
        let firmware: heapless::Vec<u8, 4096> = (0..2500).map(|i| (i % 251) as u8).collect();
        let mut scratch = [0u8; 2048];
        let dcl = DclImages::new(
            MockHttp {
                firmware: firmware.clone(),
            },
            "https://on.dcl.csa-iot.org",
            &mut scratch,
        );
        block_on(async {
            let mut fd_buf = [0u8; 64];
            let OtaQueryOutcome::Available(meta) =
                dcl.query(0x1234, 0x5678, 1, false, &mut fd_buf).await
            else {
                panic!("expected an available image");
            };
            assert_eq!(meta.version, 10);
            assert_eq!(meta.file_designator, "1234-5678-10");
            assert_eq!(meta.update_token, b"1234-5678-10");
            assert_eq!(meta.size, Some(2500));
            assert!(!meta.user_consent_needed);
            let fd = b"1234-5678-10";
            assert_eq!(dcl.size(fd).await, Some(2500));
            // Download in 300-byte chunks until the mock signals the end with `Ok(0)`.
            let mut out: heapless::Vec<u8, 4096> = heapless::Vec::new();
            let mut rbuf = [0u8; 300];
            let mut offset = 0u64;
            loop {
                let n = dcl.read(fd, offset, &mut rbuf).await.unwrap();
                if n == 0 {
                    break;
                }
                out.extend_from_slice(&rbuf[..n]).unwrap();
                offset += n as u64;
            }
            assert_eq!(out, firmware);
            let mut fd_buf = [0u8; 64];
            assert!(matches!(
                dcl.query(0x1234, 0x5678, 10, false, &mut fd_buf).await,
                OtaQueryOutcome::NotAvailable
            ));
        });
    }
}