1#![no_std]
2
3use core::panic::PanicInfo;
4
5#[panic_handler]
6fn panic(_info: &PanicInfo<'_>) -> ! {
7 loop {}
8}
9
10pub mod network;
11pub mod tensor;
12pub mod scratch;
13pub mod rnn_format;
14pub mod rnn_api;
15pub mod crypto;
16pub mod conv3d;
17pub mod conv5d;
18pub mod sphere5d;
19pub mod activations;
20pub mod layers;
21pub mod engine;
22pub mod model_format;
23pub mod losses;
24pub mod metrics;
25pub mod initializers;
26pub mod inference;
27pub mod trainer;
28pub mod optimizers;
29pub mod schedulers;
30pub mod normalization;
31pub mod attention;
32pub mod quantization;
33pub mod model_config;
34pub mod runtime;
35pub mod sampling;
36pub mod kv_cache;
37pub mod rope;
38pub mod embeddings;
39pub mod lora;
40pub mod moe;
41pub mod beam_search;
42pub mod gradients;
43pub mod batching;
44pub mod profiler;
45
46mod public_api;
47
48pub use public_api::*;
49
pub mod math {
    //! Freestanding `f32` math for `no_std` builds.
    //!
    //! On stable Rust, `core` provides no float `sqrt`, `exp`, `ln` or
    //! trigonometric functions, so this module implements them from bit
    //! manipulation plus short polynomial kernels. Edge-case behaviour is
    //! documented on each function.

    use core::f32::consts::{LN_2, LOG2_E, SQRT_2};

    /// 2^23: every `f32` whose magnitude is at least this is already an integer.
    const INTEGRAL_THRESHOLD: f32 = 8_388_608.0;

    /// Square root of `x`.
    ///
    /// The "fast inverse square root" bit trick seeds `1/sqrt(v)`, two Newton
    /// steps refine it, and one Heron step on `sqrt(v)` itself removes the
    /// remaining few-ppm error.
    ///
    /// Returns NaN for negative or NaN input, and `x` unchanged for `±0.0` and
    /// `+inf`. Subnormal inputs are first scaled by 2^24, because the bit-trick
    /// seed assumes a normal exponent field.
    #[inline]
    pub fn sqrtf(x: f32) -> f32 {
        if x == 0.0 {
            return x;
        }
        if !(x > 0.0) {
            return f32::NAN;
        }
        if x.is_infinite() {
            return x;
        }
        // sqrt(v * 2^24) == sqrt(v) * 2^12, undone by multiplying with 2^-12.
        let (v, rescale): (f32, f32) = if x < f32::MIN_POSITIVE {
            (x * 16_777_216.0, 1.0 / 4096.0)
        } else {
            (x, 1.0)
        };
        let half = 0.5 * v;
        let mut inv = f32::from_bits(0x5f37_59df_u32.wrapping_sub(v.to_bits() >> 1));
        inv *= 1.5 - half * inv * inv;
        inv *= 1.5 - half * inv * inv;
        let s = v * inv;
        // Heron step: quadratic convergence brings the result to ~1 ULP.
        let s = 0.5 * (s + v / s);
        s * rescale
    }

    /// Largest integer not greater than `x`.
    ///
    /// NaN, ±inf and values with magnitude ≥ 2^23 are returned unchanged. They
    /// are already integral, and a round trip through `i32` would saturate.
    #[inline]
    fn floorf(x: f32) -> f32 {
        if !(x > -INTEGRAL_THRESHOLD && x < INTEGRAL_THRESHOLD) {
            return x;
        }
        let t = x as i32 as f32; // truncates toward zero
        if t > x { t - 1.0 } else { t }
    }

    /// Computes `x * 2^exp`.
    ///
    /// The scale is applied in up to three multiplications, so every power of
    /// two built from raw bits has a normal exponent (the scheme of musl's
    /// `scalbnf`). This keeps the sign of `x`, produces subnormal results
    /// instead of flushing them to zero, overflows to ±inf, and passes NaN and
    /// ±inf inputs through.
    #[inline]
    fn ldexpf(x: f32, exp: i32) -> f32 {
        let two_pow_127 = f32::from_bits(0x7f00_0000);
        let two_pow_neg_102 = f32::from_bits(0x0c80_0000); // 2^-126 * 2^24
        let mut y = x;
        let mut n = exp;
        if n > 127 {
            y *= two_pow_127;
            n -= 127;
            if n > 127 {
                y *= two_pow_127;
                n -= 127;
                if n > 127 {
                    n = 127;
                }
            }
        } else if n < -126 {
            y *= two_pow_neg_102;
            n += 102;
            if n < -126 {
                y *= two_pow_neg_102;
                n += 102;
                if n < -126 {
                    n = -126;
                }
            }
        }
        // n is now in [-126, 127], so the biased exponent is in [1, 254].
        y * f32::from_bits(((n + 127) as u32) << 23)
    }

    /// `e^x`.
    ///
    /// The input is saturated to `[-88, 88]` before evaluation, so the result
    /// is always finite. Large positive `x` yields `e^88` (≈1.65e38) instead of
    /// `+inf`, and large negative `x` yields the subnormal `e^-88` instead of 0.
    /// `tanhf` relies on that finite upper bound. NaN propagates.
    ///
    /// Method: `x = n·ln2 + r` with `n` rounded to nearest, so `|r| ≤ ln2/2`.
    /// `ln2` is split so that `n * LN2_HI` is exact. A degree-7 Taylor
    /// polynomial gives `e^r`, and the result is scaled exactly by `2^n`.
    #[inline]
    pub fn expf(x: f32) -> f32 {
        // LN2_HI is 0x3f31_7200 (15 significant bits): n * LN2_HI is exact for |n| <= 127.
        const LN2_HI: f32 = 0.693_145_75;
        const LN2_LO: f32 = 1.428_606_8e-6; // ln2 - LN2_HI
        if x.is_nan() {
            return x;
        }
        let x = if x > 88.0 {
            88.0
        } else if x < -88.0 {
            -88.0
        } else {
            x
        };
        let n = roundf(x * LOG2_E);
        let r = (x - n * LN2_HI) - n * LN2_LO;
        let p = 1.0
            + r * (1.0
                + r * (0.5
                    + r * (1.0 / 6.0
                        + r * (1.0 / 24.0
                            + r * (1.0 / 120.0 + r * (1.0 / 720.0 + r * (1.0 / 5040.0)))))));
        ldexpf(p, n as i32)
    }

    /// Natural logarithm of `x`.
    ///
    /// Returns NaN for NaN, zero and negative input. Note that `0.0` gives NaN,
    /// not `-inf`. Returns `+inf` for `+inf`. Subnormals are rescaled by 2^24
    /// before decomposition.
    ///
    /// Method: `x = m · 2^e` with `m` in `[√2/2, √2)`, and
    /// `ln m = 2·atanh((m-1)/(m+1))` from its odd series. Centering `m` on 1
    /// keeps inputs just below 1 accurate, where `e·ln2` and `ln m` would
    /// otherwise nearly cancel.
    #[inline]
    pub fn lnf(x: f32) -> f32 {
        if x.is_nan() {
            return x;
        }
        if x <= 0.0 {
            return f32::NAN;
        }
        if x.is_infinite() {
            return x;
        }
        let (bits, mut e) = if x < f32::MIN_POSITIVE {
            ((x * 16_777_216.0).to_bits(), -24_i32)
        } else {
            (x.to_bits(), 0_i32)
        };
        e += ((bits >> 23) & 0xff) as i32 - 127;
        let mut m = f32::from_bits((bits & 0x007f_ffff) | 0x3f80_0000);
        if m > SQRT_2 {
            m *= 0.5;
            e += 1;
        }
        let y = (m - 1.0) / (m + 1.0);
        let y2 = y * y;
        let ln_m = 2.0
            * y
            * (1.0 + y2 * (1.0 / 3.0 + y2 * (1.0 / 5.0 + y2 * (1.0 / 7.0 + y2 * (1.0 / 9.0)))));
        ln_m + (e as f32) * LN_2
    }

    /// `x` raised to the power `y`, computed as `e^(y·ln|x|)`.
    ///
    /// Special cases:
    /// - `y == 0` gives 1, even for NaN `x`.
    /// - Otherwise NaN propagates.
    /// - `x == 1` gives 1.
    /// - `x == ±0` gives 0 for `y > 0` and `+inf` for `y < 0`.
    /// - Negative `x` is defined only for integral `y`, with the sign flipped
    ///   for odd `y`; other `y` give NaN.
    ///
    /// Large magnitudes saturate through `expf` rather than overflowing.
    #[inline]
    pub fn powf(x: f32, y: f32) -> f32 {
        if y == 0.0 {
            return 1.0;
        }
        if x.is_nan() || y.is_nan() {
            return f32::NAN;
        }
        if x == 1.0 {
            return 1.0;
        }
        if x == 0.0 {
            return if y > 0.0 { 0.0 } else { f32::INFINITY };
        }
        if x < 0.0 {
            if floorf(y) != y {
                return f32::NAN;
            }
            let magnitude = expf(y * lnf(-x));
            // Every f32 with |y| >= 2^24 is even.
            let odd = y > -16_777_216.0 && y < 16_777_216.0 && (y as i32) % 2 != 0;
            return if odd { -magnitude } else { magnitude };
        }
        expf(y * lnf(x))
    }

    /// Reduces `x` modulo π/2 in `f64`.
    ///
    /// Returns `(q, r)` with `x ≈ k·π/2 + r`, `q = k mod 4` and `|r| ≤ ~π/4`.
    /// π/2 is split Cody–Waite style into `PIO2_1` (its leading 25 bits) and
    /// the tail `PIO2_1T`. `k * PIO2_1` is therefore exact for `|k| < 2^28`,
    /// which covers `|x|` up to ≈4.2e8 in a single pass. Larger inputs are
    /// reduced repeatedly until `r` is small. The result stays bounded but is
    /// no longer the correctly reduced argument.
    fn rem_pio2f(x: f32) -> (u32, f64) {
        // Adding then subtracting 1.5 * 2^52 rounds an f64 to the nearest integer.
        const TOINT: f64 = 6_755_399_441_055_744.0;
        const INV_PIO2: f64 = core::f64::consts::FRAC_2_PI;
        const PIO2_1: f64 = 1.570_796_310_901_641_8;
        const PIO2_1T: f64 = 1.589_325_477_352_819_7e-8;
        const TWO_POW_63: f64 = 9_223_372_036_854_775_808.0;

        let mut r = f64::from(x);
        let mut q = 0_u32;
        // Each extra pass shrinks a huge remainder by ~2^-52, so a few passes
        // cover the whole f32 range. The bound only guards termination.
        for _ in 0..8 {
            let k = (r * INV_PIO2 + TOINT) - TOINT;
            if k == 0.0 {
                break;
            }
            // k is integral. At 2^63 and above, every f64 integer is a multiple of 4.
            if k > -TWO_POW_63 && k < TWO_POW_63 {
                q = (q + ((k as i64) as u32 & 3)) & 3;
            }
            r = r - k * PIO2_1 - k * PIO2_1T;
        }
        (q, r)
    }

    /// Taylor polynomial for `sin r`, truncation error below 2e-9 for `|r| ≤ π/4`.
    #[inline]
    fn sin_kernel(r: f64) -> f64 {
        let r2 = r * r;
        r + r
            * r2
            * (-1.0 / 6.0 + r2 * (1.0 / 120.0 + r2 * (-1.0 / 5_040.0 + r2 * (1.0 / 362_880.0))))
    }

    /// Taylor polynomial for `cos r`, truncation error below 2e-10 for `|r| ≤ π/4`.
    #[inline]
    fn cos_kernel(r: f64) -> f64 {
        let r2 = r * r;
        1.0 + r2
            * (-0.5
                + r2 * (1.0 / 24.0
                    + r2 * (-1.0 / 720.0 + r2 * (1.0 / 40_320.0 + r2 * (-1.0 / 3_628_800.0)))))
    }

    /// Sine of `x` (radians).
    ///
    /// The argument is reduced modulo π/2 in `f64`, and the quadrant selects a
    /// sine or cosine kernel, so accuracy holds across the whole period. For
    /// `|x|` above ≈4.2e8 the result stays finite and bounded but loses
    /// accuracy. Returns NaN for NaN or infinite input.
    #[inline]
    pub fn sinf(x: f32) -> f32 {
        if !x.is_finite() {
            return f32::NAN;
        }
        let (quadrant, r) = rem_pio2f(x);
        let s = match quadrant {
            0 => sin_kernel(r),
            1 => cos_kernel(r),
            2 => -sin_kernel(r),
            _ => -cos_kernel(r),
        };
        s as f32
    }

    /// Cosine of `x` (radians).
    ///
    /// Uses the same reduction, accuracy range and NaN handling as `sinf`.
    #[inline]
    pub fn cosf(x: f32) -> f32 {
        if !x.is_finite() {
            return f32::NAN;
        }
        let (quadrant, r) = rem_pio2f(x);
        let c = match quadrant {
            0 => cos_kernel(r),
            1 => -sin_kernel(r),
            2 => -cos_kernel(r),
            _ => sin_kernel(r),
        };
        c as f32
    }

    /// Hyperbolic tangent of `x`.
    ///
    /// For `|x| < 0.25` an odd Taylor series is used, because the direct
    /// formula's `e^(2x) - 1` cancels near 0 and collapses to exactly 0 once
    /// `e^(2x)` rounds to 1. Otherwise it computes `(e^(2x) - 1) / (e^(2x) + 1)`.
    /// The saturation in `expf` keeps this finite, giving ±1 for large `|x|`.
    /// NaN propagates.
    #[inline]
    pub fn tanhf(x: f32) -> f32 {
        if x > -0.25 && x < 0.25 {
            let x2 = x * x;
            // x - x^3/3 + 2x^5/15 - 17x^7/315 + 62x^9/2835
            return x + x
                * x2
                * (-1.0 / 3.0 + x2 * (2.0 / 15.0 + x2 * (-17.0 / 315.0 + x2 * (62.0 / 2835.0))));
        }
        let e2 = expf(2.0 * x);
        (e2 - 1.0) / (e2 + 1.0)
    }

    /// Rounds `x` to the nearest integer, with ties rounded away from zero.
    ///
    /// NaN, ±inf and values with magnitude ≥ 2^23 (already integral) are
    /// returned unchanged. The fractional part is compared directly, instead
    /// of computing `x + 0.5`, which rounds up to 1.0 for the largest `f32`
    /// below 0.5.
    #[inline]
    pub fn roundf(x: f32) -> f32 {
        if !(x > -INTEGRAL_THRESHOLD && x < INTEGRAL_THRESHOLD) {
            return x;
        }
        let t = x as i32 as f32; // truncates toward zero
        let frac = x - t; // exact for |x| < 2^23
        if frac >= 0.5 {
            t + 1.0
        } else if frac <= -0.5 {
            t - 1.0
        } else {
            t
        }
    }
}