Skip to main content

Device

Enum Device 

Source
/// Represents a compute device where tensors can be allocated and operations executed.
pub enum Device {
    /// CPU device (always available).
    Cpu,
}
Expand description

Represents a compute device where tensors can be allocated and operations executed.

Variants§

§

Cpu

CPU device (always available).

Implementations§

Source§

impl Device

Source

pub fn is_available(self) -> bool

Returns true if this device is available on the current system.

Source

pub const fn is_cpu(self) -> bool

Returns true if this is a CPU device.

Source

pub const fn is_gpu(self) -> bool

Returns true if this is a GPU device.

Examples found in repository?
examples/mnist_training.rs (line 126)
// MNIST training example: trains LeNet on a synthetic MNIST dataset, selecting a
// CUDA device when the `cuda` feature is enabled and it is available, else the CPU.
// (Leading numerals on each line are rustdoc line-number artifacts from the docs page.)
34fn main() {
35    println!("=== AxonML - MNIST Training (LeNet) ===\n");
36
37    // -------------------------------------------------------------------------
38    // Device Detection
39    // -------------------------------------------------------------------------
40
41    // Detect device
42    #[cfg(feature = "cuda")]
43    let device = {
44        let cuda = Device::Cuda(0);
45        if cuda.is_available() {
46            println!("GPU detected: using CUDA device 0");
47            cuda
48        } else {
49            println!("CUDA feature enabled but no GPU available, using CPU");
50            Device::Cpu
51        }
52    };
53    #[cfg(not(feature = "cuda"))]
54    let device = {
55        println!("Using CPU (compile with --features cuda for GPU)");
56        Device::Cpu
57    };
58
59    // -------------------------------------------------------------------------
60    // Dataset and DataLoader Setup
61    // -------------------------------------------------------------------------
62
63    // 1. Create dataset
64    let num_train = 2000;
65    let num_test = 400;
66    println!("\n1. Creating SyntheticMNIST dataset ({num_train} train, {num_test} test)...");
67    let train_dataset = SyntheticMNIST::new(num_train);
68    let test_dataset = SyntheticMNIST::new(num_test);
69
70    // 2. Create DataLoader
71    let batch_size = 64;
72    println!("2. Creating DataLoader (batch_size={batch_size})...");
73    let train_loader = DataLoader::new(train_dataset, batch_size);
74    let test_loader = DataLoader::new(test_dataset, batch_size);
75    println!("   Training batches: {}", train_loader.len());
76
77    // -------------------------------------------------------------------------
78    // Model, Optimizer, and Loss
79    // -------------------------------------------------------------------------
80
81    // 3. Create LeNet model and move to device
82    println!("3. Creating LeNet model...");
83    let model = LeNet::new();
    // NOTE(review): `to_device` is called for its side effect on an immutable
    // binding — presumably it moves parameters in place via interior mutability;
    // confirm against the LeNet API.
84    model.to_device(device);
85    let params = model.parameters();
    // NOTE(review): counting weights via `to_vec()` copies every parameter tensor
    // into a host Vec just to take its length — fine for a demo banner, but
    // wasteful if the tensors live on a GPU.
86    let total_params: usize = params
87        .iter()
88        .map(|p| p.variable().data().to_vec().len())
89        .sum();
90    println!(
91        "   Parameters: {} ({} total weights)",
92        params.len(),
93        total_params
94    );
95    println!("   Device: {:?}", device);
96
97    // 4. Create optimizer and loss
98    println!("4. Creating Adam optimizer (lr=0.001) + CrossEntropyLoss...");
99    let mut optimizer = Adam::new(params, 0.001);
100    let criterion = CrossEntropyLoss::new();
101
102    // -------------------------------------------------------------------------
103    // Training Loop
104    // -------------------------------------------------------------------------
105
106    // 5. Training loop
107    let epochs = 10;
108    println!("5. Training for {epochs} epochs...\n");
109
110    let train_start = Instant::now();
111
112    for epoch in 0..epochs {
113        let epoch_start = Instant::now();
114        let mut total_loss = 0.0;
115        let mut correct = 0usize;
116        let mut total = 0usize;
117        let mut batch_count = 0;
118
119        for batch in train_loader.iter() {
120            let bs = batch.data.shape()[0];
121
122            // Reshape to [N, 1, 28, 28] and create Variable
123            let input_data = batch.data.to_vec();
124            let input_tensor = Tensor::from_vec(input_data, &[bs, 1, 28, 28]).unwrap();
            // NOTE(review): `true` here requests gradients w.r.t. the raw input
            // pixels, which are never used in this loop — the eval path below
            // passes `false`. Confirm whether `false` was intended here too.
125            let input = Variable::new(
126                if device.is_gpu() {
127                    input_tensor.to_device(device).unwrap()
128                } else {
129                    input_tensor
130                },
131                true,
132            );
133
134            // Target: convert one-hot [N, 10] to class indices [N]
            // (argmax over each one-hot row recovers the class index 0..=9)
135            let target_onehot = batch.targets.to_vec();
136            let mut target_indices = vec![0.0f32; bs];
137            for i in 0..bs {
138                let offset = i * 10;
139                let mut max_idx = 0;
140                let mut max_val = f32::NEG_INFINITY;
141                for c in 0..10 {
142                    if target_onehot[offset + c] > max_val {
143                        max_val = target_onehot[offset + c];
144                        max_idx = c;
145                    }
146                }
147                target_indices[i] = max_idx as f32;
148            }
149            let target_tensor = Tensor::from_vec(target_indices.clone(), &[bs]).unwrap();
150            let target = Variable::new(
151                if device.is_gpu() {
152                    target_tensor.to_device(device).unwrap()
153                } else {
154                    target_tensor
155                },
156                false,
157            );
158
159            // Forward pass
160            let output = model.forward(&input);
161
162            // Cross-entropy loss
163            let loss = criterion.compute(&output, &target);
164
165            let loss_val = loss.data().to_vec()[0];
166            total_loss += loss_val;
167            batch_count += 1;
168
169            // Compute training accuracy
            // (argmax of the pre-update logits; `as usize` is exact since the
            // index values are 0..=9)
170            let out_data = output.data().to_vec();
171            for i in 0..bs {
172                let offset = i * 10;
173                let mut pred = 0;
174                let mut pred_val = f32::NEG_INFINITY;
175                for c in 0..10 {
176                    if out_data[offset + c] > pred_val {
177                        pred_val = out_data[offset + c];
178                        pred = c;
179                    }
180                }
181                if pred == target_indices[i] as usize {
182                    correct += 1;
183                }
184                total += 1;
185            }
186
187            // Backward pass
188            loss.backward();
189
190            // Update weights
191            optimizer.step();
192            optimizer.zero_grad();
193        }
194
195        let epoch_time = epoch_start.elapsed();
196        let avg_loss = total_loss / batch_count as f32;
197        let accuracy = 100.0 * correct as f32 / total as f32;
198        let samples_per_sec = total as f64 / epoch_time.as_secs_f64();
199
200        println!(
201            "   Epoch {:2}/{}: Loss={:.4}  Acc={:.1}%  ({:.0} samples/s, {:.2}s)",
202            epoch + 1,
203            epochs,
204            avg_loss,
205            accuracy,
206            samples_per_sec,
207            epoch_time.as_secs_f64(),
208        );
209    }
210
211    let train_time = train_start.elapsed();
212    println!("\n   Total training time: {:.2}s", train_time.as_secs_f64());
213
214    // -------------------------------------------------------------------------
215    // Test Evaluation
216    // -------------------------------------------------------------------------
217
218    // 6. Test evaluation
219    println!("\n6. Evaluating on test set...");
220
221    // Disable gradient computation for evaluation
222    let (correct, total) = no_grad(|| {
223        let mut correct = 0usize;
224        let mut total = 0usize;
225
226        for batch in test_loader.iter() {
227            let bs = batch.data.shape()[0];
228
229            let input_data = batch.data.to_vec();
230            let input_tensor = Tensor::from_vec(input_data, &[bs, 1, 28, 28]).unwrap();
231            let input = Variable::new(
232                if device.is_gpu() {
233                    input_tensor.to_device(device).unwrap()
234                } else {
235                    input_tensor
236                },
237                false,
238            );
239
240            let target_onehot = batch.targets.to_vec();
241            let output = model.forward(&input);
242            let out_data = output.data().to_vec();
243
244            for i in 0..bs {
245                // Prediction: argmax of output
246                let offset = i * 10;
247                let mut pred = 0;
248                let mut pred_val = f32::NEG_INFINITY;
249                for c in 0..10 {
250                    if out_data[offset + c] > pred_val {
251                        pred_val = out_data[offset + c];
252                        pred = c;
253                    }
254                }
255
256                // True label: argmax of one-hot target
257                let mut true_label = 0;
258                let mut true_val = f32::NEG_INFINITY;
259                for c in 0..10 {
260                    if target_onehot[i * 10 + c] > true_val {
261                        true_val = target_onehot[i * 10 + c];
262                        true_label = c;
263                    }
264                }
265
266                if pred == true_label {
267                    correct += 1;
268                }
269                total += 1;
270            }
271        }
272
273        (correct, total)
274    });
275
276    let test_accuracy = 100.0 * correct as f32 / total as f32;
277    println!(
278        "   Test Accuracy: {}/{} ({:.2}%)",
279        correct, total, test_accuracy
280    );
281
282    println!("\n=== Training Complete! ===");
283    println!("   Device: {:?}", device);
284    println!("   Final test accuracy: {:.2}%", test_accuracy);
285}
Source

pub const fn index(self) -> usize

Returns the device index for GPU devices, or 0 for CPU.

Source

pub const fn device_type(self) -> &'static str

Returns the name of this device type.

Source

pub const fn cpu() -> Device

Returns the default CPU device.

Source§

impl Device

Source

pub fn capabilities(self) -> DeviceCapabilities

Returns the capabilities of this device.

Trait Implementations§

Source§

impl Clone for Device

Source§

fn clone(&self) -> Device

Returns a duplicate of the value. Read more
1.0.0 · Source§

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source. Read more
Source§

impl Debug for Device

Source§

fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error>

Formats the value using the given formatter. Read more
Source§

impl Default for Device

Source§

fn default() -> Device

Returns the “default value” for a type. Read more
Source§

impl Display for Device

Source§

fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error>

Formats the value using the given formatter. Read more
Source§

impl Hash for Device

Source§

fn hash<__H>(&self, state: &mut __H)
where __H: Hasher,

Feeds this value into the given Hasher. Read more
1.3.0 · Source§

fn hash_slice<H>(data: &[Self], state: &mut H)
where H: Hasher, Self: Sized,

Feeds a slice of this type into the given Hasher. Read more
Source§

impl PartialEq for Device

Source§

fn eq(&self, other: &Device) -> bool

Tests for self and other values to be equal, and is used by ==.
1.0.0 · Source§

fn ne(&self, other: &Rhs) -> bool

Tests for !=. The default implementation is almost always sufficient, and should not be overridden without very good reason.
Source§

impl Copy for Device

Source§

impl Eq for Device

Source§

impl StructuralPartialEq for Device

Auto Trait Implementations§

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> CloneToUninit for T
where T: Clone,

Source§

unsafe fn clone_to_uninit(&self, dest: *mut u8)

🔬This is a nightly-only experimental API. (clone_to_uninit)
Performs copy-assignment from self to dest. Read more
Source§

impl<Q, K> Equivalent<K> for Q
where Q: Eq + ?Sized, K: Borrow<Q> + ?Sized,

Source§

fn equivalent(&self, key: &K) -> bool

Checks if this value is equivalent to the given key. Read more
Source§

impl<Q, K> Equivalent<K> for Q
where Q: Eq + ?Sized, K: Borrow<Q> + ?Sized,

Source§

fn equivalent(&self, key: &K) -> bool

Compare self to key and return true if they are equal.
Source§

impl<Q, K> Equivalent<K> for Q
where Q: Eq + ?Sized, K: Borrow<Q> + ?Sized,

Source§

fn equivalent(&self, key: &K) -> bool

Checks if this value is equivalent to the given key. Read more
Source§

impl<Q, K> Equivalent<K> for Q
where Q: Eq + ?Sized, K: Borrow<Q> + ?Sized,

Source§

fn equivalent(&self, key: &K) -> bool

Checks if this value is equivalent to the given key. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T> Instrument for T

Source§

fn instrument(self, span: Span) -> Instrumented<Self>

Instruments this type with the provided Span, returning an Instrumented wrapper. Read more
Source§

fn in_current_span(self) -> Instrumented<Self>

Instruments this type with the current Span, returning an Instrumented wrapper. Read more
Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T> IntoEither for T

Source§

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

impl<T> Pointable for T

Source§

const ALIGN: usize

The alignment of the pointer.
Source§

type Init = T

The type for initializers.
Source§

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a pointable object with the given initializer. Read more
Source§

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer. Read more
Source§

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer. Read more
Source§

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer. Read more
Source§

impl<T> PolicyExt for T
where T: ?Sized,

Source§

fn and<P, B, E>(self, other: P) -> And<T, P>
where T: Policy<B, E>, P: Policy<B, E>,

Create a new Policy that returns Action::Follow only if self and other return Action::Follow. Read more
Source§

fn or<P, B, E>(self, other: P) -> Or<T, P>
where T: Policy<B, E>, P: Policy<B, E>,

Create a new Policy that returns Action::Follow if either self or other returns Action::Follow. Read more
Source§

impl<R, P> ReadPrimitive<R> for P
where R: Read + ReadEndian<P>, P: Default,

Source§

fn read_from_little_endian(read: &mut R) -> Result<Self, Error>

Read this value from the supplied reader. Same as ReadEndian::read_from_little_endian().
Source§

fn read_from_big_endian(read: &mut R) -> Result<Self, Error>

Read this value from the supplied reader. Same as ReadEndian::read_from_big_endian().
Source§

fn read_from_native_endian(read: &mut R) -> Result<Self, Error>

Read this value from the supplied reader. Same as ReadEndian::read_from_native_endian().
Source§

impl<T> ToOwned for T
where T: Clone,

Source§

type Owned = T

The resulting type after obtaining ownership.
Source§

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning. Read more
Source§

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning. Read more
Source§

impl<T> ToString for T
where T: Display + ?Sized,

Source§

fn to_string(&self) -> String

Converts the given value to a String. Read more
Source§

impl<T> ToStringFallible for T
where T: Display,

Source§

fn try_to_string(&self) -> Result<String, TryReserveError>

ToString::to_string, but without panic on OOM.

Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.
Source§

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

Source§

fn vzip(self) -> V

Source§

impl<T> WithSubscriber for T

Source§

fn with_subscriber<S>(self, subscriber: S) -> WithDispatch<Self>
where S: Into<Dispatch>,

Attaches the provided Subscriber to this type, returning a WithDispatch wrapper. Read more
Source§

fn with_current_subscriber(self) -> WithDispatch<Self>

Attaches the current default Subscriber to this type, returning a WithDispatch wrapper. Read more