1use js_sys::Uint8Array;
2use ort::session::SessionOutputs;
3use ort_sys::{OrtErrorCode, stub::Error};
4
5use crate::{
6 binding,
7 tensor::{SyncDirection, ValueExt},
8 util::value_to_string
9};
10
/// Magic bytes (`FC 86 A5 01`) written into the leading `sentinel` field of every [`Session`].
pub const SESSION_SENTINEL: [u8; 4] = *b"\xFC\x86\xA5\x01";
12
13#[repr(C)]
14pub struct Session {
15 sentinel: [u8; 4],
16 pub js: binding::InferenceSession,
17 pub disable_sync: bool
18}
19
20impl Session {
21 pub async fn from_url(uri: &str, options: &SessionOptions) -> Result<Self, Error> {
22 Ok(Session {
23 sentinel: SESSION_SENTINEL,
24 js: binding::InferenceSession::create_from_uri(uri, &options.js)
25 .await
26 .map_err(|e| Error::new(OrtErrorCode::ORT_FAIL, value_to_string(&e)))?,
27 disable_sync: options.disable_sync
28 })
29 }
30
31 pub async fn from_bytes(bytes: &[u8], options: &SessionOptions) -> Result<Self, Error> {
32 Ok(Session {
33 sentinel: SESSION_SENTINEL,
34 js: binding::InferenceSession::create_from_bytes(
35 &unsafe { Uint8Array::view(bytes) },
37 &options.js
38 )
39 .await
40 .map_err(|e| Error::new(OrtErrorCode::ORT_FAIL, value_to_string(&e)))?,
41 disable_sync: options.disable_sync
42 })
43 }
44}
45
/// Options for a single inference run.
///
/// Currently carries no settings.
pub struct RunOptions {}

impl RunOptions {
	/// Creates an empty set of run options. Usable in `const` contexts.
	pub const fn new() -> Self {
		RunOptions {}
	}
}

impl Default for RunOptions {
	/// Equivalent to [`RunOptions::new`].
	fn default() -> Self {
		Self::new()
	}
}
53
54pub async fn sync_outputs(outputs: &mut SessionOutputs<'_>) -> crate::Result<()> {
66 for (_, mut value) in outputs.iter_mut() {
67 value.sync(SyncDirection::Rust).await?;
68 }
69 Ok(())
70}
71
72#[derive(Clone)]
73pub struct SessionOptions {
74 pub js: binding::SessionOptions,
75 pub disable_sync: bool
76}
77
78impl SessionOptions {
79 pub fn new() -> Self {
80 Self {
81 js: binding::SessionOptions::default(),
82 disable_sync: true
83 }
84 }
85}