pub struct CudaTrainer { /* private fields */ }
CUDA-accelerated training context
Manages GPU resources and provides high-level training operations.
Implementations§
Source§impl CudaTrainer
impl CudaTrainer
Sourcepub fn with_device(device_id: i32) -> Result<Self>
pub fn with_device(device_id: i32) -> Result<Self>
Create a new CUDA trainer on the specified GPU
Sourcepub fn context(&self) -> &Arc<CudaContext>
pub fn context(&self) -> &Arc<CudaContext>
Get the CUDA context
Sourcepub fn stream(&self) -> &CudaStream
pub fn stream(&self) -> &CudaStream
Get the CUDA stream
Sourcepub fn synchronize(&self) -> Result<()>
pub fn synchronize(&self) -> Result<()>
Synchronize the stream (wait for all operations to complete)
Sourcepub fn upload(&self, data: &[f32]) -> Result<GpuBuffer<f32>>
pub fn upload(&self, data: &[f32]) -> Result<GpuBuffer<f32>>
Allocate a GPU buffer from host data
Sourcepub fn zeros(&self, len: usize) -> Result<GpuBuffer<f32>>
pub fn zeros(&self, len: usize) -> Result<GpuBuffer<f32>>
Allocate a zero-initialized GPU buffer
Sourcepub fn free_memory_mb(&self) -> Option<u64>
pub fn free_memory_mb(&self) -> Option<u64>
Query free VRAM in MB (via cuMemGetInfo). Returns None if query fails.
Sourcepub fn download(&self, buffer: &GpuBuffer<f32>) -> Result<Vec<f32>>
pub fn download(&self, buffer: &GpuBuffer<f32>) -> Result<Vec<f32>>
Download GPU buffer to host
Sourcepub fn matmul_forward(
&self,
a: &GpuBuffer<f32>,
b: &GpuBuffer<f32>,
c: &mut GpuBuffer<f32>,
m: u32,
k: u32,
n: u32,
) -> Result<()>
pub fn matmul_forward( &self, a: &GpuBuffer<f32>, b: &GpuBuffer<f32>, c: &mut GpuBuffer<f32>, m: u32, k: u32, n: u32, ) -> Result<()>
Matrix multiply forward pass: C = A @ B
§Arguments
- `a`: input matrix (m × k)
- `b`: weight matrix (k × n)
- `c`: output matrix (m × n)
- `m`, `k`, `n`: matrix dimensions
Sourcepub fn matmul_backward(
&self,
a: &GpuBuffer<f32>,
b: &GpuBuffer<f32>,
grad_c: &GpuBuffer<f32>,
grad_a: &mut GpuBuffer<f32>,
grad_b: &mut GpuBuffer<f32>,
m: u32,
k: u32,
n: u32,
) -> Result<()>
pub fn matmul_backward( &self, a: &GpuBuffer<f32>, b: &GpuBuffer<f32>, grad_c: &GpuBuffer<f32>, grad_a: &mut GpuBuffer<f32>, grad_b: &mut GpuBuffer<f32>, m: u32, k: u32, n: u32, ) -> Result<()>
Matrix multiply backward pass: computes the gradients with respect to both inputs
Given C = A @ B, computes:
- grad_A = grad_C @ B^T
- grad_B = A^T @ grad_C
Sourcepub fn adamw_step(
&mut self,
params: &mut GpuBuffer<f32>,
grads: &GpuBuffer<f32>,
m_state: &mut GpuBuffer<f32>,
v_state: &mut GpuBuffer<f32>,
lr: f32,
beta1: f32,
beta2: f32,
eps: f32,
weight_decay: f32,
) -> Result<()>
pub fn adamw_step( &mut self, params: &mut GpuBuffer<f32>, grads: &GpuBuffer<f32>, m_state: &mut GpuBuffer<f32>, v_state: &mut GpuBuffer<f32>, lr: f32, beta1: f32, beta2: f32, eps: f32, weight_decay: f32, ) -> Result<()>
AdamW optimizer step on GPU
Updates weights in-place using the AdamW algorithm.
Sourcepub fn clip_gradients(
&self,
grads: &mut GpuBuffer<f32>,
max_norm: f32,
) -> Result<()>
pub fn clip_gradients( &self, grads: &mut GpuBuffer<f32>, max_norm: f32, ) -> Result<()>
Apply gradient clipping
Sourcepub fn step_count(&self) -> u32
pub fn step_count(&self) -> u32
Get current optimizer step count
Sourcepub fn reset_step(&mut self)
pub fn reset_step(&mut self)
Reset optimizer step count (for new training run)
Sourcepub fn device_name(&self) -> String
pub fn device_name(&self) -> String
Get device name
Sourcepub fn total_memory(&self) -> usize
pub fn total_memory(&self) -> usize
Get total GPU memory in bytes
Trait Implementations§
Auto Trait Implementations§
impl Freeze for CudaTrainer
impl RefUnwindSafe for CudaTrainer
impl Send for CudaTrainer
impl Sync for CudaTrainer
impl Unpin for CudaTrainer
impl UnsafeUnpin for CudaTrainer
impl UnwindSafe for CudaTrainer
Blanket Implementations§
impl<T> Allocation for T
impl<T> Allocation for T
Source§impl<T> BorrowMut<T> for Twhere
T: ?Sized,
impl<T> BorrowMut<T> for Twhere
T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Source§impl<T> FmtForward for T
impl<T> FmtForward for T
Source§fn fmt_binary(self) -> FmtBinary<Self>where
Self: Binary,
fn fmt_binary(self) -> FmtBinary<Self>where
Self: Binary,
self to use its Binary implementation when Debug-formatted.Source§fn fmt_display(self) -> FmtDisplay<Self>where
Self: Display,
fn fmt_display(self) -> FmtDisplay<Self>where
Self: Display,
self to use its Display implementation when
Debug-formatted.Source§fn fmt_lower_exp(self) -> FmtLowerExp<Self>where
Self: LowerExp,
fn fmt_lower_exp(self) -> FmtLowerExp<Self>where
Self: LowerExp,
self to use its LowerExp implementation when
Debug-formatted.Source§fn fmt_lower_hex(self) -> FmtLowerHex<Self>where
Self: LowerHex,
fn fmt_lower_hex(self) -> FmtLowerHex<Self>where
Self: LowerHex,
self to use its LowerHex implementation when
Debug-formatted.Source§fn fmt_octal(self) -> FmtOctal<Self>where
Self: Octal,
fn fmt_octal(self) -> FmtOctal<Self>where
Self: Octal,
self to use its Octal implementation when Debug-formatted.Source§fn fmt_pointer(self) -> FmtPointer<Self>where
Self: Pointer,
fn fmt_pointer(self) -> FmtPointer<Self>where
Self: Pointer,
self to use its Pointer implementation when
Debug-formatted.Source§fn fmt_upper_exp(self) -> FmtUpperExp<Self>where
Self: UpperExp,
fn fmt_upper_exp(self) -> FmtUpperExp<Self>where
Self: UpperExp,
self to use its UpperExp implementation when
Debug-formatted.Source§fn fmt_upper_hex(self) -> FmtUpperHex<Self>where
Self: UpperHex,
fn fmt_upper_hex(self) -> FmtUpperHex<Self>where
Self: UpperHex,
self to use its UpperHex implementation when
Debug-formatted.Source§impl<T> Instrument for T
impl<T> Instrument for T
Source§fn instrument(self, span: Span) -> Instrumented<Self>
fn instrument(self, span: Span) -> Instrumented<Self>
Source§fn in_current_span(self) -> Instrumented<Self>
fn in_current_span(self) -> Instrumented<Self>
Source§impl<T> IntoEither for T
impl<T> IntoEither for T
Source§fn into_either(self, into_left: bool) -> Either<Self, Self>
fn into_either(self, into_left: bool) -> Either<Self, Self>
self into a Left variant of Either<Self, Self>
if into_left is true.
Converts self into a Right variant of Either<Self, Self>
otherwise. Read moreSource§fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
self into a Left variant of Either<Self, Self>
if into_left(&self) returns true.
Converts self into a Right variant of Either<Self, Self>
otherwise. Read moreSource§impl<T> Pipe for Twhere
T: ?Sized,
impl<T> Pipe for Twhere
T: ?Sized,
Source§fn pipe<R>(self, func: impl FnOnce(Self) -> R) -> Rwhere
Self: Sized,
fn pipe<R>(self, func: impl FnOnce(Self) -> R) -> Rwhere
Self: Sized,
Source§fn pipe_ref<'a, R>(&'a self, func: impl FnOnce(&'a Self) -> R) -> Rwhere
R: 'a,
fn pipe_ref<'a, R>(&'a self, func: impl FnOnce(&'a Self) -> R) -> Rwhere
R: 'a,
self and passes that borrow into the pipe function. Read moreSource§fn pipe_ref_mut<'a, R>(&'a mut self, func: impl FnOnce(&'a mut Self) -> R) -> Rwhere
R: 'a,
fn pipe_ref_mut<'a, R>(&'a mut self, func: impl FnOnce(&'a mut Self) -> R) -> Rwhere
R: 'a,
self and passes that borrow into the pipe function. Read moreSource§fn pipe_borrow<'a, B, R>(&'a self, func: impl FnOnce(&'a B) -> R) -> R
fn pipe_borrow<'a, B, R>(&'a self, func: impl FnOnce(&'a B) -> R) -> R
Source§fn pipe_borrow_mut<'a, B, R>(
&'a mut self,
func: impl FnOnce(&'a mut B) -> R,
) -> R
fn pipe_borrow_mut<'a, B, R>( &'a mut self, func: impl FnOnce(&'a mut B) -> R, ) -> R
Source§fn pipe_as_ref<'a, U, R>(&'a self, func: impl FnOnce(&'a U) -> R) -> R
fn pipe_as_ref<'a, U, R>(&'a self, func: impl FnOnce(&'a U) -> R) -> R
self, then passes self.as_ref() into the pipe function.Source§fn pipe_as_mut<'a, U, R>(&'a mut self, func: impl FnOnce(&'a mut U) -> R) -> R
fn pipe_as_mut<'a, U, R>(&'a mut self, func: impl FnOnce(&'a mut U) -> R) -> R
self, then passes self.as_mut() into the pipe
function.Source§fn pipe_deref<'a, T, R>(&'a self, func: impl FnOnce(&'a T) -> R) -> R
fn pipe_deref<'a, T, R>(&'a self, func: impl FnOnce(&'a T) -> R) -> R
self, then passes self.deref() into the pipe function.Source§impl<T> Pointable for T
impl<T> Pointable for T
Source§impl<T> PolicyExt for Twhere
T: ?Sized,
impl<T> PolicyExt for Twhere
T: ?Sized,
Source§impl<T> Tap for T
impl<T> Tap for T
Source§fn tap_borrow<B>(self, func: impl FnOnce(&B)) -> Self
fn tap_borrow<B>(self, func: impl FnOnce(&B)) -> Self
Borrow<B> of a value. Read moreSource§fn tap_borrow_mut<B>(self, func: impl FnOnce(&mut B)) -> Self
fn tap_borrow_mut<B>(self, func: impl FnOnce(&mut B)) -> Self
BorrowMut<B> of a value. Read moreSource§fn tap_ref<R>(self, func: impl FnOnce(&R)) -> Self
fn tap_ref<R>(self, func: impl FnOnce(&R)) -> Self
AsRef<R> view of a value. Read moreSource§fn tap_ref_mut<R>(self, func: impl FnOnce(&mut R)) -> Self
fn tap_ref_mut<R>(self, func: impl FnOnce(&mut R)) -> Self
AsMut<R> view of a value. Read moreSource§fn tap_deref<T>(self, func: impl FnOnce(&T)) -> Self
fn tap_deref<T>(self, func: impl FnOnce(&T)) -> Self
Deref::Target of a value. Read moreSource§fn tap_deref_mut<T>(self, func: impl FnOnce(&mut T)) -> Self
fn tap_deref_mut<T>(self, func: impl FnOnce(&mut T)) -> Self
Deref::Target of a value. Read moreSource§fn tap_dbg(self, func: impl FnOnce(&Self)) -> Self
fn tap_dbg(self, func: impl FnOnce(&Self)) -> Self
.tap() only in debug builds, and is erased in release builds.Source§fn tap_mut_dbg(self, func: impl FnOnce(&mut Self)) -> Self
fn tap_mut_dbg(self, func: impl FnOnce(&mut Self)) -> Self
.tap_mut() only in debug builds, and is erased in release
builds.Source§fn tap_borrow_dbg<B>(self, func: impl FnOnce(&B)) -> Self
fn tap_borrow_dbg<B>(self, func: impl FnOnce(&B)) -> Self
.tap_borrow() only in debug builds, and is erased in release
builds.Source§fn tap_borrow_mut_dbg<B>(self, func: impl FnOnce(&mut B)) -> Self
fn tap_borrow_mut_dbg<B>(self, func: impl FnOnce(&mut B)) -> Self
.tap_borrow_mut() only in debug builds, and is erased in release
builds.Source§fn tap_ref_dbg<R>(self, func: impl FnOnce(&R)) -> Self
fn tap_ref_dbg<R>(self, func: impl FnOnce(&R)) -> Self
.tap_ref() only in debug builds, and is erased in release
builds.Source§fn tap_ref_mut_dbg<R>(self, func: impl FnOnce(&mut R)) -> Self
fn tap_ref_mut_dbg<R>(self, func: impl FnOnce(&mut R)) -> Self
.tap_ref_mut() only in debug builds, and is erased in release
builds.Source§fn tap_deref_dbg<T>(self, func: impl FnOnce(&T)) -> Self
fn tap_deref_dbg<T>(self, func: impl FnOnce(&T)) -> Self
.tap_deref() only in debug builds, and is erased in release
builds.