//! hubro-sdk 1.0.0
//!
//! Hubro Platform SDK crate.
use crate::records::{HealthRecord, Sampleable};
use rand::SeedableRng;
use rand::prelude::*;
use rand::rngs::SmallRng;
use serde::de;
use serde::{Deserialize, Serialize};
use std::ffi::CStr;
use std::mem;
use std::os::raw::{c_char, c_void};
use wasip1;

/// Represents the platform where the plugin is currently running.
///
/// Produced by [`get_current_platform`] from the host's platform string. The type is a
/// plain fieldless enum, so it is `Copy` and can be used as a map/set key.
///
/// Index-based serde formats (such as bincode) encode unit variants by position. New
/// variants should therefore be appended, not inserted.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Platform {
    /// Apple iOS.
    IOS,
    /// Google Android.
    Android,
    /// JavaScript/Browser environment.
    JS,
    /// Unknown or unsupported platform (also used when the host reports nothing).
    Unknown,
}

/// Host functions imported from the `hubro_sdk` WASM module for legacy core Wasm engines.
///
/// This block is only compiled when the target is *not* WASI Preview 2
/// (`target_os = "wasi"`, `target_env = "p2"`). On Preview 2 the wrappers in this file call
/// the bindings under `crate::local::hubro_sdk::mobile_host` instead. These signatures are
/// the ABI contract with the host and must match its exports exactly.
#[cfg(not(all(target_os = "wasi", target_env = "p2")))]
#[link(wasm_import_module = "hubro_sdk")]
unsafe extern "C" {
    /// Returns a pointer to a NUL-terminated JSON array of `record_type` records between
    /// `from` and `to`.
    /// NOTE(review): the caller does not free this string; confirm the host's ownership and
    /// null-return contract.
    fn get_health_connect_records(record_type: i32, from: i32, to: i32) -> *mut c_char;
    /// Returns the number of `record_type` records between `from` and `to`. Callers treat
    /// `-1` as "host busy" and retry.
    fn get_health_connect_number_of_records(record_type: i32, from: i32, to: i32) -> i32;
    /// Prints the NUL-terminated string at `nf_name` on the host console.
    /// NOTE(review): the guest never frees this buffer; presumably the host takes ownership.
    fn print_line(nf_name: *mut c_char);
    /// Returns a NUL-terminated platform name, or null. Callers lower-case it and look for
    /// "ios", "android" or "js".
    fn get_platform() -> *mut c_char;
    /// Returns a NUL-terminated JSON array of usage records between `from` and `to`, or null.
    fn get_usage_data(from: i32, to: i32) -> *mut c_char;
    /// Returns the number of usage records between `from` and `to`. Callers treat `-1` as
    /// "host busy" and retry.
    fn get_usage_data_number_of_records(from: i32, to: i32) -> i32;
    /// Sends the `len` bytes at `ptr` (a serialized TFHE server key) to the host. The `i32`
    /// result is passed back to callers unchanged.
    fn upload_server_key(ptr: *const u8, len: usize) -> i32;
    /// Asks the host to load a model from the UTF-8 path at `path_ptr` / `path_len`. The path
    /// is not NUL-terminated. The `i32` result is passed back to callers unchanged.
    fn host_load_model(path_ptr: *const u8, path_len: usize) -> i32;
    /// Runs model `model_id` on `input_len` floats at `input_ptr`, with `shape_len`
    /// dimensions at `shape_ptr`. The host writes its output into the `output_len`-float
    /// buffer at `output_ptr` (presumably at most `output_len` values — TODO confirm).
    /// The `i32` result is passed back to callers unchanged.
    fn host_run_inference(
        model_id: i32,
        input_ptr: *const f32,
        input_len: i32,
        shape_ptr: *const i32,
        shape_len: i32,
        output_ptr: *mut f32,
        output_len: i32,
    ) -> i32;
}

/// Allocates `size` bytes on the guest (WASM) heap and returns a pointer to them.
///
/// The bytes are uninitialized. Rust gives up ownership of the block, so whoever receives
/// the pointer (presumably the host) must release it with [`dealloc`], passing the same
/// `size` as `cap`. For `size == 0` the pointer is dangling but non-null.
#[unsafe(no_mangle)]
pub extern "C" fn allocate(size: usize) -> *mut c_void {
    // `c_void` is one byte wide, so a capacity of `size` elements is exactly `size` bytes.
    // `ManuallyDrop` keeps the Vec from freeing the block when this function returns.
    let mut block = mem::ManuallyDrop::new(Vec::<c_void>::with_capacity(size));
    block.as_mut_ptr()
}

/// Deallocates memory previously allocated by [`allocate`].
///
/// `cap` must be the `size` that was passed to [`allocate`] for this `ptr`. A null `ptr` is
/// ignored, because there is nothing to free.
///
/// Caller contract: the signature is not `unsafe`, but passing a pointer that did not come
/// from [`allocate`], a mismatched `cap`, or the same pointer twice is undefined behavior.
#[unsafe(no_mangle)]
pub extern "C" fn dealloc(ptr: *mut c_void, cap: usize) {
    // `Vec::from_raw_parts` requires a non-null pointer even when it owns nothing, so a null
    // from the host must not reach it.
    if ptr.is_null() {
        return;
    }
    // SAFETY: per the contract above, `ptr`/`cap` describe a live `Vec<c_void>` buffer from
    // `allocate`. Rebuilding it with length 0 and dropping it frees the block without running
    // any element destructors.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, cap));
    }
}

/// Fetches the count of records from the host, retrying if the host is busy.
///
/// The host reports "busy" as `-1`. The query is then repeated immediately, with no delay
/// and no retry limit, until another value comes back. That value is returned unchanged.
pub fn fetch_records_count(record_type: i32, from: i32, to: i32) -> i32 {
    loop {
        #[cfg(all(target_os = "wasi", target_env = "p2"))]
        let count = crate::local::hubro_sdk::mobile_host::get_health_connect_number_of_records(
            record_type,
            from,
            to,
        );

        // SAFETY: only plain integers cross the host boundary.
        #[cfg(not(all(target_os = "wasi", target_env = "p2")))]
        let count = unsafe { get_health_connect_number_of_records(record_type, from, to) };

        if count != -1 {
            break count;
        }
    }
}

/// Fetches health records of type `T` from the host and deserializes them from JSON.
///
/// The host is first asked how many records exist (see [`fetch_records_count`]). If the
/// count is zero or negative, no payload is requested and an empty `Vec` is returned.
///
/// Failures are handled the same way as in [`get_usage_data_records`]: a null pointer from
/// the host, a payload that is not valid UTF-8, or JSON that does not parse as `Vec<T>` all
/// yield an empty `Vec` instead of causing undefined behavior or a panic.
fn fetch_records<T: de::DeserializeOwned + HealthRecord>(
    record_type: i32,
    from: i32,
    to: i32,
) -> Vec<T> {
    if fetch_records_count(record_type, from, to) <= 0 {
        return Vec::new();
    }

    #[cfg(all(target_os = "wasi", target_env = "p2"))]
    let subject_str: String =
        crate::local::hubro_sdk::mobile_host::get_health_connect_records(record_type, from, to);

    #[cfg(not(all(target_os = "wasi", target_env = "p2")))]
    let subject_str: String = {
        // SAFETY: only plain integers cross the host boundary.
        let ptr = unsafe { get_health_connect_records(record_type, from, to) };
        // `CStr::from_ptr` on null is undefined behavior, so reject it first.
        if ptr.is_null() {
            return Vec::new();
        }
        // SAFETY: `ptr` is non-null. This assumes the host returns a NUL-terminated string
        // that stays valid until it is copied below.
        match unsafe { CStr::from_ptr(ptr) }.to_str() {
            Ok(json) => json.to_owned(),
            Err(_) => return Vec::new(),
        }
    };

    serde_json::from_str::<Vec<T>>(&subject_str).unwrap_or_default()
}

/// Retrieves health records of type `T` from the host for a given time range.
///
/// The host-side record kind comes from `T::IDENTIFIER`, so the type parameter alone selects
/// what is fetched, e.g. `get_health_records::<SomeRecord>(from, to)`.
pub fn get_health_records<T: de::DeserializeOwned + HealthRecord>(from: i32, to: i32) -> Vec<T> {
    let record_type = T::IDENTIFIER;
    fetch_records(record_type, from, to)
}

/// Retrieves the count of usage data records from the host for a given time range.
///
/// A reply of `-1` means the host is busy. The query is repeated straight away, without
/// delay or limit, until another value arrives. That value is returned unchanged.
pub fn get_usage_data_count(from: i32, to: i32) -> i32 {
    // One round-trip to the host.
    let query_host = || {
        #[cfg(all(target_os = "wasi", target_env = "p2"))]
        let reply =
            crate::local::hubro_sdk::mobile_host::get_usage_data_number_of_records(from, to);

        // SAFETY: only plain integers cross the host boundary.
        #[cfg(not(all(target_os = "wasi", target_env = "p2")))]
        let reply = unsafe { get_usage_data_number_of_records(from, to) };

        reply
    };

    let mut reply = query_host();
    while reply == -1 {
        reply = query_host();
    }
    reply
}

/// Retrieves usage data records of type `T` from the host for a given time range.
///
/// Best effort: a null pointer from the host, a payload that is not valid UTF-8, or JSON that
/// does not parse as `Vec<T>` all yield an empty `Vec` rather than a panic.
pub fn get_usage_data_records<T: de::DeserializeOwned>(from: i32, to: i32) -> Vec<T> {
    #[cfg(all(target_os = "wasi", target_env = "p2"))]
    let subject_str: String = crate::local::hubro_sdk::mobile_host::get_usage_data(from, to);

    #[cfg(not(all(target_os = "wasi", target_env = "p2")))]
    let subject_str: String = {
        // SAFETY: only plain integers cross the host boundary.
        let ptr = unsafe { get_usage_data(from, to) };
        if ptr.is_null() {
            return Vec::new();
        }
        // SAFETY: `ptr` is non-null. This assumes the host returns a NUL-terminated string
        // that stays valid until it is copied below. Invalid UTF-8 is treated like invalid
        // JSON (empty result) instead of panicking.
        match unsafe { CStr::from_ptr(ptr) }.to_str() {
            Ok(json) => json.to_owned(),
            Err(_) => return Vec::new(),
        }
    };

    serde_json::from_str::<Vec<T>>(&subject_str).unwrap_or_default()
}

/// Generates sample health records for testing or fallback purposes.
///
/// One batch is generated for every *whole* day from `from` to `to` (Unix seconds, UTC). A
/// trailing partial day is skipped, and an empty or reversed range produces no records.
/// Each day starts at `from + day * 24h` and is split into fixed 4-hour slots. A day gets
/// 3 to 6 records (one per slot, in order), and each record spans one slot.
///
/// The RNG seed uses the WASI realtime clock for its first 8 bytes and `getrandom` for the
/// remaining 24. Both are best effort: if either fails, those seed bytes stay zero.
pub fn generate_sample_records<T: Serialize + Sampleable>(from: i32, to: i32) -> Vec<T> {
    // Length of one sample slot; each generated record covers exactly one slot.
    const SLOT_HOURS: i64 = 4;
    // Number of non-overlapping slots in a day (6). A 7th slot would start at +24h and
    // collide with the next day's first record.
    const SLOTS_PER_DAY: i64 = 24 / SLOT_HOURS;

    let mut seed = [0u8; 32];
    if let Ok(time) = unsafe { wasip1::clock_time_get(wasip1::CLOCKID_REALTIME, 1_000) } {
        seed[0..8].copy_from_slice(&time.to_le_bytes());
    }
    let _ = getrandom::getrandom(&mut seed[8..32]);
    let mut rng = SmallRng::from_seed(seed);

    // Any `i32` number of seconds is within chrono's representable range.
    let start_time = chrono::DateTime::from_timestamp(i64::from(from), 0)
        .expect("i32 Unix timestamp is always representable");
    let end_time = chrono::DateTime::from_timestamp(i64::from(to), 0)
        .expect("i32 Unix timestamp is always representable");
    let days = end_time.signed_duration_since(start_time).num_days();

    let slot = chrono::Duration::hours(SLOT_HOURS);
    let mut records = Vec::new();
    for day in 0..days {
        let day_start = start_time + chrono::Duration::days(day);
        let count = rng.random_range(3..=SLOTS_PER_DAY);

        for index in 0..count {
            let timestamp = day_start + chrono::Duration::hours(index * SLOT_HOURS);
            records.push(T::generate_sample(&mut rng, timestamp, slot));
        }
    }
    records
}

/// Prints a debug string to the host's output console.
///
/// On WASI Preview 2 the `&str` goes directly to the component import. On legacy engines,
/// a NUL-terminated copy is written into a fresh guest buffer from [`allocate`], and that
/// buffer's address is passed to the host.
///
/// NOTE(review): the legacy path never calls [`dealloc`] on that buffer. That is only correct
/// if the host takes ownership and frees it; otherwise each call leaks `output.len() + 1`
/// bytes. An interior NUL in `output` also truncates what the host sees. Confirm against the
/// host implementation.
pub fn debug_print_line(output: &str) {
    #[cfg(all(target_os = "wasi", target_env = "p2"))]
    {
        crate::local::hubro_sdk::mobile_host::print_line(output);
    }

    #[cfg(not(all(target_os = "wasi", target_env = "p2")))]
    {
        let len = output.len();
        let buffer = allocate(len + 1).cast::<u8>();
        // SAFETY: `buffer` points at `len + 1` freshly allocated bytes, so it cannot overlap
        // `output`. The first `len` bytes get the string and the last one the NUL terminator.
        unsafe {
            std::ptr::copy_nonoverlapping(output.as_ptr(), buffer, len);
            buffer.add(len).write(0);
            print_line(buffer.cast::<c_char>());
        }
    }
}

/// Detects the current platform by querying the host environment.
///
/// The host's platform string is lower-cased and checked for the keywords "ios", "android"
/// and "js", in that order. The first keyword it contains wins. A null reply (legacy
/// engines) or no matching keyword gives [`Platform::Unknown`].
pub fn get_current_platform() -> Platform {
    #[cfg(all(target_os = "wasi", target_env = "p2"))]
    let name: String = crate::local::hubro_sdk::mobile_host::get_platform();

    #[cfg(not(all(target_os = "wasi", target_env = "p2")))]
    let name: String = {
        // SAFETY: the import takes no arguments.
        let raw = unsafe { get_platform() };
        if raw.is_null() {
            return Platform::Unknown;
        }
        // SAFETY: `raw` is non-null; assumes the host returns a NUL-terminated string.
        unsafe { CStr::from_ptr(raw) }.to_string_lossy().into_owned()
    };

    let name = name.to_lowercase();
    [
        ("ios", Platform::IOS),
        ("android", Platform::Android),
        ("js", Platform::JS),
    ]
    .into_iter()
    .find(|(keyword, _)| name.contains(*keyword))
    .map_or(Platform::Unknown, |(_, platform)| platform)
}

/// Loads a machine learning model from the specified path on the host.
///
/// The host's `i32` result is returned unchanged. Callers pass it to [`run_inference`] as
/// `model_id`; any error encoding is defined by the host.
pub fn load_model(path: &str) -> i32 {
    #[cfg(all(target_os = "wasi", target_env = "p2"))]
    let model_id = crate::local::hubro_sdk::mobile_host::host_load_model(path);

    // SAFETY: `path` is a live `&str` for the whole call. The host gets its pointer and byte
    // length; no NUL terminator is appended.
    // NOTE(review): assumes the host copies the path before returning.
    #[cfg(not(all(target_os = "wasi", target_env = "p2")))]
    let model_id = unsafe { host_load_model(path.as_ptr(), path.len()) };

    model_id
}

/// Runs inference on a loaded machine learning model.
///
/// `input` holds the tensor values and `shape` its dimensions (presumably a flattened tensor
/// plus its dims — TODO confirm with the host). Results are written into `output`, whose
/// length tells the host how many values the caller can accept.
///
/// The return value depends on the backend:
/// * WASI Preview 2: always `0`. Up to `output.len()` values from the host are copied into
///   `output`. If the host returns fewer, the remaining slots keep their previous contents.
/// * Legacy engines: the host's `i32` result is returned unchanged.
///
/// NOTE(review): on Preview 2 a short or failed result cannot be told apart from success.
pub fn run_inference(model_id: i32, input: &[f32], shape: &[i32], output: &mut [f32]) -> i32 {
    #[cfg(all(target_os = "wasi", target_env = "p2"))]
    {
        let capacity = output.len();
        let values = crate::local::hubro_sdk::mobile_host::host_run_inference(
            model_id,
            input,
            shape,
            capacity as i32,
        );
        let filled = capacity.min(values.len());
        output[..filled].copy_from_slice(&values[..filled]);
        0
    }

    #[cfg(not(all(target_os = "wasi", target_env = "p2")))]
    {
        let input_len = input.len() as i32;
        let shape_len = shape.len() as i32;
        let output_len = output.len() as i32;
        // SAFETY: every pointer comes from a live slice whose length is passed with it, and
        // `output` is exclusively borrowed for the duration of the call.
        unsafe {
            host_run_inference(
                model_id,
                input.as_ptr(),
                input_len,
                shape.as_ptr(),
                shape_len,
                output.as_mut_ptr(),
                output_len,
            )
        }
    }
}

/// Uploads a TFHE server key to the host for use in encrypted computations.
///
/// The key is encoded with [`serialize_keys`] and the bytes are handed to the host import.
/// The host's `i32` result is returned unchanged. Only available with the `mobile` feature
/// on non-Preview-2 targets.
///
/// # Panics
/// Panics if the key cannot be serialized (see [`serialize_keys`]).
#[cfg(all(feature = "mobile", not(all(target_os = "wasi", target_env = "p2"))))]
pub fn upload_key(server_key: &tfhe::ServerKey) -> i32 {
    let encoded: Vec<u8> = serialize_keys(server_key);
    // SAFETY: `encoded` outlives the call, and the host receives its pointer and exact length.
    unsafe { upload_server_key(encoded.as_ptr(), encoded.len()) }
}

/// Uploads an already-serialized TFHE server key to the host (WASI Preview 2 builds only).
///
/// On this target the caller supplies the encoded key bytes. They are forwarded unchanged to
/// the component import, and the host's `i32` result is returned as-is. Requires the
/// `mobile` feature.
#[cfg(all(feature = "mobile", target_os = "wasi", target_env = "p2"))]
pub fn upload_server_key(server_key: &[u8]) -> i32 {
    crate::local::hubro_sdk::mobile_host::upload_server_key(server_key)
}

/// Returns the current time in nanoseconds since the Unix epoch using the WASI realtime clock.
///
/// The requested precision is `1_000` ns (WASI's "maximum lag" argument).
///
/// # Panics
/// Panics if the WASI `clock_time_get` call returns an error.
pub fn get_current_time_nanos() -> u64 {
    // SAFETY: the raw WASI import takes only plain integer arguments.
    let reading = unsafe { wasip1::clock_time_get(wasip1::CLOCKID_REALTIME, 1_000) };
    reading.expect("WASI clock_time_get failed")
}

/// Serializes a TFHE `ServerKey` into a byte vector using `bincode`.
///
/// Uses bincode's `standard()` configuration through its serde bridge.
///
/// # Panics
/// Panics if bincode fails to encode the key.
#[cfg(all(feature = "mobile", not(all(target_os = "wasi", target_env = "p2"))))]
pub fn serialize_keys(server_key: &tfhe::ServerKey) -> Vec<u8> {
    bincode::serde::encode_to_vec(server_key, bincode::config::standard())
        .expect("Failed to serialize TFHE server key using Serde bridge")
}

/// Quantizes a 32-bit float to a 16-bit unsigned integer within a given range.
///
/// `weight` is clamped to `[x_min, x_max]` and mapped linearly onto `0..=65535`, so `x_min`
/// maps to `0` and `x_max` to `65535`. The result is rounded to the nearest step. A NaN
/// `weight` yields `0`.
///
/// Degenerate ranges return `0` instead of panicking. This covers `x_min == x_max`, which
/// would otherwise divide 0 by 0, and also `x_min > x_max` or a NaN bound, both of which
/// make `f32::clamp` panic.
pub fn quantize_float_to_u16(weight: f32, x_min: f32, x_max: f32) -> u16 {
    const Q_MIN: f32 = 0.0;
    const Q_MAX: f32 = 65535.0;

    if x_min.is_nan() || x_max.is_nan() || x_max <= x_min {
        return 0;
    }

    let scale = (x_max - x_min) / (Q_MAX - Q_MIN);
    let clamped = weight.clamp(x_min, x_max);
    // `as u16` saturates (and maps NaN to 0), so float error at the top end cannot wrap.
    (((clamped - x_min) / scale) + Q_MIN).round() as u16
}