Struct DqnAgent 

pub struct DqnAgent<E, Enc, Act, B, Buf = CircularBuffer<<E as Environment>::Observation, <E as Environment>::Action>> { /* private fields */ }

A DQN agent.

Implements ε-greedy action selection, experience replay, and temporal-difference (TD) learning with a target network. Generic over:

  • E: the environment type (must satisfy rl_traits::Environment)
  • Enc: the observation encoder (converts E::Observation to tensors)
  • Act: the action mapper (converts E::Action to/from integer indices)
  • B: the Burn backend (e.g. NdArray, Wgpu)
  • Buf: the replay buffer (defaults to CircularBuffer; swap in prioritised experience replay (PER) or another ReplayBuffer implementation)
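The TD learning step mentioned above can be illustrated on its own. The following is a minimal, dependency-free sketch of the standard DQN target, assuming plain Vec<f64> Q-values in place of Burn tensors; Transition, td_target, and td_loss are hypothetical names, not part of this crate, and the loss this crate actually uses is not stated on this page.

```rust
/// A sampled transition, reduced to what the TD target needs. Plain vectors
/// stand in for Burn tensors (assumption for illustration).
struct Transition {
    reward: f64,
    done: bool,
    /// Q-values of the next state from the *target* network, one per action.
    next_q_target: Vec<f64>,
}

/// TD target: y = r + gamma * max_a' Q_target(s', a'), or just r if the episode ended.
fn td_target(t: &Transition, gamma: f64) -> f64 {
    if t.done {
        return t.reward; // no bootstrapping past a terminal state
    }
    let max_next = t
        .next_q_target
        .iter()
        .copied()
        .fold(f64::NEG_INFINITY, f64::max);
    t.reward + gamma * max_next
}

/// Mean squared TD error between the online network's Q(s, a) for the taken
/// actions and the targets. (Many DQN implementations use a Huber loss instead.)
fn td_loss(q_taken: &[f64], targets: &[f64]) -> f64 {
    assert_eq!(q_taken.len(), targets.len());
    let n = targets.len() as f64;
    q_taken.iter().zip(targets).map(|(q, y)| (q - y).powi(2)).sum::<f64>() / n
}

fn main() {
    let batch = vec![
        Transition { reward: 1.0, done: false, next_q_target: vec![0.5, 2.0] },
        Transition { reward: 1.0, done: true, next_q_target: vec![0.5, 2.0] },
    ];
    let targets: Vec<f64> = batch.iter().map(|t| td_target(t, 0.99)).collect();
    println!("targets = {targets:?}, loss = {}", td_loss(&[2.5, 0.5], &targets));
}
```

Using a separate, slowly updated target network for the max term keeps the regression target from shifting with every gradient step on the online network.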

Usage

let agent = DqnAgent::new(encoder, action_mapper, config, device, seed);

Then hand it to DqnRunner, which drives the training loop.
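DqnRunner's internals are not documented on this page. As a rough sketch of the kind of loop a runner drives (choose an action, step the environment, pass the transition to the agent, reset at episode end), here is a self-contained version. Env, Learner, FiveSteps, EveryOther, and run are hypothetical stand-ins for rl_traits::Environment, DqnAgent, and DqnRunner, not the crate's API.

```rust
/// Hypothetical environment interface (a stand-in for rl_traits::Environment).
trait Env {
    fn reset(&mut self) -> i32;
    /// Returns (next_obs, reward, done).
    fn step(&mut self, action: usize) -> (i32, f64, bool);
}

/// Hypothetical agent interface: the two calls a runner needs.
trait Learner {
    fn act(&mut self, obs: i32) -> usize;
    /// Returns true if a gradient update ran.
    fn observe(&mut self, obs: i32, action: usize, reward: f64, next_obs: i32, done: bool) -> bool;
}

/// Toy environment: every episode lasts 5 steps.
struct FiveSteps {
    t: i32,
}

impl Env for FiveSteps {
    fn reset(&mut self) -> i32 {
        self.t = 0;
        self.t
    }
    fn step(&mut self, _action: usize) -> (i32, f64, bool) {
        self.t += 1;
        (self.t, 1.0, self.t >= 5)
    }
}

/// Toy learner that reports an update on every second transition.
struct EveryOther {
    seen: usize,
}

impl Learner for EveryOther {
    fn act(&mut self, _obs: i32) -> usize {
        0
    }
    fn observe(&mut self, _obs: i32, _a: usize, _r: f64, _next: i32, _done: bool) -> bool {
        self.seen += 1;
        self.seen % 2 == 0
    }
}

/// act -> step -> observe, resetting at episode end. Returns (episodes, updates).
fn run<E: Env, L: Learner>(env: &mut E, agent: &mut L, steps: usize) -> (usize, usize) {
    let mut obs = env.reset();
    let (mut episodes, mut updates) = (0, 0);
    for _ in 0..steps {
        let action = agent.act(obs);
        let (next_obs, reward, done) = env.step(action);
        if agent.observe(obs, action, reward, next_obs, done) {
            updates += 1;
        }
        obs = if done {
            episodes += 1;
            env.reset()
        } else {
            next_obs
        };
    }
    (episodes, updates)
}

fn main() {
    let (episodes, updates) = run(&mut FiveSteps { t: 0 }, &mut EveryOther { seen: 0 }, 12);
    println!("{episodes} episodes, {updates} updates");
}
```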

Implementations


impl<E, Enc, Act, B> DqnAgent<E, Enc, Act, B>


pub fn new( encoder: Enc, action_mapper: Act, config: DqnConfig, device: B::Device, seed: u64, ) -> Self

Create a new agent using the default CircularBuffer replay buffer.

Buffer capacity is taken from config.buffer_capacity.
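This page does not show how CircularBuffer works internally. A fixed-capacity ring buffer is the usual structure behind a buffer sized this way: once capacity is reached, each new transition overwrites the oldest. The sketch below assumes that behaviour; RingBuffer and its methods are hypothetical and not the crate's CircularBuffer API.

```rust
/// Fixed-capacity ring buffer: a hypothetical stand-in for a replay buffer
/// sized by `config.buffer_capacity` (not the crate's `CircularBuffer`).
struct RingBuffer<T> {
    items: Vec<T>,
    capacity: usize,
    next: usize, // slot the next push writes to
}

impl<T> RingBuffer<T> {
    fn with_capacity(capacity: usize) -> Self {
        assert!(capacity > 0, "capacity must be positive");
        Self { items: Vec::with_capacity(capacity), capacity, next: 0 }
    }

    /// Store a transition; once full, overwrite the oldest one.
    fn push(&mut self, item: T) {
        if self.items.len() < self.capacity {
            self.items.push(item);
        } else {
            self.items[self.next] = item;
        }
        self.next = (self.next + 1) % self.capacity;
    }

    fn len(&self) -> usize {
        self.items.len()
    }

    /// Access by slot; uniform sampling would draw slots from 0..len().
    fn get(&self, slot: usize) -> Option<&T> {
        self.items.get(slot)
    }
}

fn main() {
    let mut buf = RingBuffer::with_capacity(3);
    for step in 1..=5 {
        buf.push(step);
    }
    // Capacity 3: transitions 1 and 2 have been overwritten by 4 and 5.
    let kept: Vec<i32> = (0..buf.len()).filter_map(|i| buf.get(i).copied()).collect();
    println!("{kept:?}");
}
```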


impl<E, Enc, Act, B, Buf> DqnAgent<E, Enc, Act, B, Buf>


pub fn new_with_buffer( encoder: Enc, action_mapper: Act, config: DqnConfig, device: B::Device, seed: u64, buffer: Buf, ) -> Self

Create a new agent with a custom replay buffer.

Use this to swap in prioritised experience replay or any other ReplayBuffer implementation in place of the default CircularBuffer.
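The split between new and new_with_buffer follows a common Rust pattern: a defaulted type parameter, with the convenience constructor defined only on the default type. The sketch below reproduces that pattern in miniature. ReplayBuffer, Fifo, KeepAll, and Agent are hypothetical stand-ins; the crate's real ReplayBuffer trait is not shown on this page.

```rust
use std::collections::VecDeque;

/// Hypothetical minimal buffer interface; the crate's real `ReplayBuffer`
/// trait will differ.
trait ReplayBuffer<T> {
    fn push(&mut self, item: T);
    fn len(&self) -> usize;
}

/// Default choice: keep only the most recent `capacity` items.
struct Fifo<T> {
    items: VecDeque<T>,
    capacity: usize,
}

impl<T> ReplayBuffer<T> for Fifo<T> {
    fn push(&mut self, item: T) {
        if self.items.len() == self.capacity {
            self.items.pop_front(); // evict the oldest transition
        }
        self.items.push_back(item);
    }
    fn len(&self) -> usize {
        self.items.len()
    }
}

/// A drop-in alternative (standing in for e.g. a prioritised buffer).
struct KeepAll<T> {
    items: Vec<T>,
}

impl<T> ReplayBuffer<T> for KeepAll<T> {
    fn push(&mut self, item: T) {
        self.items.push(item);
    }
    fn len(&self) -> usize {
        self.items.len()
    }
}

/// `Buf` defaults to `Fifo`, mirroring `Buf = CircularBuffer<..>` on `DqnAgent`.
struct Agent<Buf = Fifo<u32>> {
    buffer: Buf,
}

impl Agent {
    /// Defined only for the default buffer, like `DqnAgent::new`.
    fn new(capacity: usize) -> Self {
        Agent { buffer: Fifo { items: VecDeque::new(), capacity } }
    }
}

impl<Buf: ReplayBuffer<u32>> Agent<Buf> {
    /// Accepts any buffer, like `DqnAgent::new_with_buffer`.
    fn new_with_buffer(buffer: Buf) -> Self {
        Agent { buffer }
    }

    fn observe(&mut self, transition: u32) {
        self.buffer.push(transition);
    }
}

fn main() {
    let mut default_agent = Agent::new(2);
    let mut custom_agent = Agent::new_with_buffer(KeepAll { items: Vec::<u32>::new() });
    for t in 0..5 {
        default_agent.observe(t);
        custom_agent.observe(t);
    }
    println!("{} vs {}", default_agent.buffer.len(), custom_agent.buffer.len());
}
```

Because the training code only depends on the trait bound, swapping buffers changes the agent's type but none of its methods.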


pub fn observe( &mut self, experience: Experience<E::Observation, E::Action>, ) -> bool

Store a transition in the replay buffer and potentially run a gradient update.

Called by the runner after every environment step. Returns true if a gradient update was performed this step.
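This page does not say when an update is triggered. A common arrangement is to wait until the buffer holds a minimum number of transitions, then train every N steps and copy the online weights into the target network every M updates. The sketch below assumes that scheme; Schedule, Trainer, and the field names learning_starts, train_every, and target_sync_every are hypothetical and may not match DqnConfig.

```rust
/// Hypothetical knobs; the real `DqnConfig` field names and defaults may differ.
struct Schedule {
    learning_starts: usize,   // minimum stored transitions before any update
    train_every: usize,       // one gradient update every N environment steps
    target_sync_every: usize, // copy online -> target every N updates
}

struct Trainer {
    cfg: Schedule,
    buffer_len: usize,
    total_steps: usize,
    updates: usize,
    target_syncs: usize,
}

impl Trainer {
    fn new(cfg: Schedule) -> Self {
        Trainer { cfg, buffer_len: 0, total_steps: 0, updates: 0, target_syncs: 0 }
    }

    /// Store one transition, then decide whether to train.
    /// Returns true when a gradient update ran this step.
    fn observe(&mut self) -> bool {
        self.buffer_len += 1; // stand-in for pushing into the replay buffer
        self.total_steps += 1;
        if self.buffer_len < self.cfg.learning_starts
            || self.total_steps % self.cfg.train_every != 0
        {
            return false;
        }
        self.updates += 1; // stand-in for: sample batch, compute TD loss, optimizer step
        if self.updates % self.cfg.target_sync_every == 0 {
            self.target_syncs += 1; // stand-in for: target <- online weights
        }
        true
    }
}

fn main() {
    let mut t = Trainer::new(Schedule { learning_starts: 4, train_every: 2, target_sync_every: 3 });
    let trained_at: Vec<usize> = (1..=10).filter(|_| t.observe()).collect();
    println!("updates at steps {trained_at:?}, target syncs: {}", t.target_syncs);
}
```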


pub fn epsilon(&self) -> f64

The current exploration probability.
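This page does not state how epsilon decays, only (under set_total_steps) that it is determined by the step counter. A common choice is a linear anneal from a start value to a floor over a fixed number of steps. The sketch below assumes exactly that; linear_epsilon and its parameters are hypothetical.

```rust
/// Linear epsilon anneal: `start` at step 0, falling to `end` at `decay_steps`,
/// then held at `end`. Hypothetical parameters; the crate's schedule may differ.
fn linear_epsilon(step: usize, start: f64, end: f64, decay_steps: usize) -> f64 {
    if decay_steps == 0 || step >= decay_steps {
        return end;
    }
    let frac = step as f64 / decay_steps as f64;
    start + frac * (end - start)
}

fn main() {
    for step in [0, 2_500, 5_000, 10_000, 50_000] {
        println!("step {step:>6}: epsilon = {:.3}", linear_epsilon(step, 1.0, 0.05, 10_000));
    }
}
```

Since the value is a pure function of the step count, restoring the counter is enough to restore exploration.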


pub fn total_steps(&self) -> usize

Total environment steps observed so far.


pub fn act_epsilon_greedy( &self, obs: &E::Observation, rng: &mut impl Rng, ) -> E::Action

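Select an action using an ε-greedy policy: with probability ε (the current epsilon) a random exploratory action, otherwise the greedy action.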
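The ε-greedy rule can be sketched over a slice of Q-values. The real method takes any `impl Rng` from the rand crate; to keep this sketch dependency-free it uses a tiny hand-rolled xorshift generator instead. XorShift64, argmax, and epsilon_greedy are hypothetical names, and mapping the chosen index to an E::Action (the action mapper's job) is left out.

```rust
/// Tiny xorshift64 PRNG so the sketch needs no crates; the real method takes
/// any `rand::Rng`. The seed must be non-zero.
struct XorShift64(u64);

impl XorShift64 {
    fn next_u64(&mut self) -> u64 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        x
    }

    /// Uniform float in [0, 1) built from the top 53 bits.
    fn next_f64(&mut self) -> f64 {
        (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
    }
}

/// Index of the largest Q-value (the first one wins on ties).
fn argmax(q: &[f64]) -> usize {
    let mut best = 0;
    for (i, &v) in q.iter().enumerate() {
        if v > q[best] {
            best = i;
        }
    }
    best
}

/// With probability `epsilon`, a uniformly random action index; otherwise the
/// greedy one. An action mapper would then turn the index into an action.
fn epsilon_greedy(q: &[f64], epsilon: f64, rng: &mut XorShift64) -> usize {
    if rng.next_f64() < epsilon {
        (rng.next_u64() % q.len() as u64) as usize
    } else {
        argmax(q)
    }
}

fn main() {
    let q = [0.1, 0.9, 0.3];
    let mut rng = XorShift64(0x9E37_79B9_7F4A_7C15);
    // Index 1 is greedy; each pick explores with probability 0.2.
    let picks: Vec<usize> = (0..10).map(|_| epsilon_greedy(&q, 0.2, &mut rng)).collect();
    println!("{picks:?}");
}
```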


pub fn save(&self, path: impl AsRef<Path>) -> Result<(), RecorderError>

Save the online network weights to a file.

Uses Burn’s CompactRecorder (MessagePack format). The recorder appends its own extension to the path, so save("run/cartpole") produces run/cartpole.mpk.

Only the online network weights are saved — the target network, replay buffer, and optimizer state are not included. This is sufficient for inference. To resume training, call load followed by set_total_steps to restore the correct epsilon.


pub fn load(self, path: impl AsRef<Path>) -> Result<Self, RecorderError>

Load network weights from a file into this agent.

Loads into the online network and immediately syncs the target network. Takes self by value and returns the updated agent so you can chain with the constructor:

let agent = DqnAgent::new(...).load("run/cartpole")?;
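The consuming signature of load is what makes the chained form work: the method takes ownership, mutates, and hands the agent back inside a Result, so `?` can sit at the end of the constructor chain. The sketch below mimics that shape with an in-memory map standing in for Burn's recorder; Weights, LoadError, Store, and Agent are hypothetical and do not touch the file system.

```rust
use std::collections::HashMap;

/// Stand-in for a network's parameters.
#[derive(Clone, Debug, PartialEq)]
struct Weights(Vec<f32>);

#[derive(Debug)]
struct LoadError(String);

/// Hypothetical in-memory "recorder": checkpoints keyed by path.
type Store = HashMap<String, Weights>;

struct Agent {
    online: Weights,
    target: Weights,
}

impl Agent {
    fn new() -> Self {
        Agent { online: Weights(vec![0.0; 3]), target: Weights(vec![0.0; 3]) }
    }

    /// Takes `self` by value and hands it back, so it chains off a constructor.
    fn load(mut self, store: &Store, path: &str) -> Result<Self, LoadError> {
        let weights = store
            .get(path)
            .ok_or_else(|| LoadError(format!("no checkpoint at {path}")))?;
        self.online = weights.clone();
        // Sync immediately so TD targets start from the loaded weights.
        self.target = self.online.clone();
        Ok(self)
    }
}

fn main() -> Result<(), LoadError> {
    let mut store = Store::new();
    store.insert("run/cartpole".to_string(), Weights(vec![0.1, 0.2, 0.3]));
    let agent = Agent::new().load(&store, "run/cartpole")?;
    println!("online = {:?}, target = {:?}", agent.online, agent.target);
    Ok(())
}
```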

pub fn into_policy(self) -> DqnPolicy<E, Enc, Act, B::InnerBackend>

Convert this trained agent into an inference-only DqnPolicy.

Strips all training state (optimizer, buffer, RNG) and downcasts the network to B::InnerBackend (no autodiff). Use this when training is complete and you want a lightweight policy for evaluation or deployment.

let policy = runner.into_agent().into_policy();
let action = policy.act(&obs);
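Stripping training state falls out of Rust's ownership rules: a method that takes self by value can move the one field it keeps into the new type, and every other field is dropped when the method returns. The sketch below shows only that move; Network, Agent, and GreedyPolicy are hypothetical, and the conversion to B::InnerBackend is not modelled.

```rust
/// Stand-in network: one weight per action.
#[derive(Clone, Debug, PartialEq)]
struct Network {
    weights: Vec<f64>,
}

/// Training-time agent: the network plus state that only learning needs.
#[allow(dead_code)]
struct Agent {
    online: Network,
    target: Network,                // used only for TD targets
    optimizer_state: Vec<f64>,      // e.g. Adam moment estimates
    replay: Vec<(f64, usize, f64)>, // (obs, action, reward) transitions
}

/// Inference-only policy: just the network.
struct GreedyPolicy {
    net: Network,
}

impl Agent {
    /// Consumes the agent: `online` moves into the policy, and the target
    /// network, optimizer state, and replay buffer are dropped here.
    fn into_policy(self) -> GreedyPolicy {
        GreedyPolicy { net: self.online }
    }
}

fn main() {
    let agent = Agent {
        online: Network { weights: vec![0.4, 0.6] },
        target: Network { weights: vec![0.0, 0.0] },
        optimizer_state: vec![0.0; 4],
        replay: vec![(1.0, 0, 1.0); 1_000],
    };
    let policy = agent.into_policy();
    // `agent` can no longer be used here; only the lightweight policy remains.
    println!("{:?}", policy.net);
}
```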

pub fn set_total_steps(&mut self, steps: usize)

Override the internal step counter.

Useful when resuming training — restores epsilon to the correct value for the point in training where the checkpoint was saved.
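Because save does not record the step counter, resuming requires keeping that number somewhere yourself. One option is a small sidecar file next to the checkpoint. The sketch below assumes a hypothetical `.steps` convention (so "run/cartpole" pairs with "run/cartpole.steps"); save_steps and load_steps are not part of this crate, and the value read back is what you would pass to set_total_steps.

```rust
use std::fs;
use std::io;
use std::path::Path;

/// Write the step counter next to a checkpoint, e.g. `run/cartpole.steps`
/// (hypothetical convention; `save` itself does not store it).
fn save_steps(base: &Path, total_steps: usize) -> io::Result<()> {
    fs::write(base.with_extension("steps"), total_steps.to_string())
}

/// Read the counter back so it can be passed to `set_total_steps`.
fn load_steps(base: &Path) -> io::Result<usize> {
    let text = fs::read_to_string(base.with_extension("steps"))?;
    text.trim()
        .parse::<usize>()
        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
}

fn main() -> io::Result<()> {
    // Hypothetical checkpoint base path, mirroring save("run/cartpole").
    let base = std::env::temp_dir().join("cartpole");
    save_steps(&base, 12_345)?;
    let steps = load_steps(&base)?;
    // With the real agent: agent.set_total_steps(steps);
    println!("resuming at step {steps}");
    Ok(())
}
```

Note that with_extension replaces any existing extension on the base path, so a base such as "cartpole.v1" would map to "cartpole.steps".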

Trait Implementations


impl<E, Enc, Act, B, Buf> Policy<<E as Environment>::Observation, <E as Environment>::Action> for DqnAgent<E, Enc, Act, B, Buf>


fn act(&self, obs: &E::Observation) -> E::Action

Greedy action selection (no exploration).

Use this for evaluation. For training, use act_epsilon_greedy.
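Greedy selection through a policy trait needs only `&self` and no RNG, which is why it suits evaluation. The sketch below mirrors the shape of the trait implemented above with a hypothetical Policy trait, a toy linear Q-function, and a constant array standing in for the action mapper; none of these are the crate's definitions.

```rust
/// Hypothetical trait with the same shape as the `Policy` implemented above:
/// observation in, action out, `&self` only (no RNG, no mutation).
trait Policy<O, A> {
    fn act(&self, obs: &O) -> A;
}

#[derive(Debug, Clone, Copy, PartialEq)]
enum Move {
    Left,
    Right,
}

/// Stand-in for the action mapper: action index -> action.
const MOVES: [Move; 2] = [Move::Left, Move::Right];

/// Toy Q-function: Q(obs, a) = weights[a] * obs.
struct LinearQ {
    weights: [f64; 2],
}

impl Policy<f64, Move> for LinearQ {
    fn act(&self, obs: &f64) -> Move {
        // Greedy: take the index with the highest Q-value; no exploration.
        let best = self
            .weights
            .iter()
            .map(|w| w * obs)
            .enumerate()
            .max_by(|a, b| a.1.total_cmp(&b.1))
            .map(|(i, _)| i)
            .expect("at least one action");
        MOVES[best]
    }
}

fn main() {
    let policy = LinearQ { weights: [1.0, -1.0] };
    println!("{:?} {:?}", policy.act(&2.0), policy.act(&-2.0));
}
```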

Auto Trait Implementations


impl<E, Enc, Act, B, Buf> Freeze for DqnAgent<E, Enc, Act, B, Buf>
where Buf: Freeze, Enc: Freeze, Act: Freeze, <B as Backend>::Device: Freeze,


impl<E, Enc, Act, B, Buf = CircularBuffer<<E as Environment>::Observation, <E as Environment>::Action>> !RefUnwindSafe for DqnAgent<E, Enc, Act, B, Buf>


impl<E, Enc, Act, B, Buf> Send for DqnAgent<E, Enc, Act, B, Buf>
where Buf: Send, Enc: Send, Act: Send, E: Send,


impl<E, Enc, Act, B, Buf = CircularBuffer<<E as Environment>::Observation, <E as Environment>::Action>> !Sync for DqnAgent<E, Enc, Act, B, Buf>


impl<E, Enc, Act, B, Buf> Unpin for DqnAgent<E, Enc, Act, B, Buf>


impl<E, Enc, Act, B, Buf> UnsafeUnpin for DqnAgent<E, Enc, Act, B, Buf>
where Buf: UnsafeUnpin, Enc: UnsafeUnpin, Act: UnsafeUnpin, <B as Backend>::Device: UnsafeUnpin,


impl<E, Enc, Act, B, Buf = CircularBuffer<<E as Environment>::Observation, <E as Environment>::Action>> !UnwindSafe for DqnAgent<E, Enc, Act, B, Buf>

Blanket Implementations


impl<T> Any for T
where T: 'static + ?Sized,


fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,


fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,


fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> From<T> for T


fn from(t: T) -> T

Returns the argument unchanged.


impl<T> Instrument for T


fn instrument(self, span: Span) -> Instrumented<Self>

Instruments this type with the provided Span, returning an Instrumented wrapper.

fn in_current_span(self) -> Instrumented<Self>

Instruments this type with the current Span, returning an Instrumented wrapper.

impl<T, U> Into<U> for T
where U: From<T>,


fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.


impl<T> IntoComptime for T


fn comptime(self) -> Self


impl<T> IntoEither for T


fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.

impl<T> Pointable for T


const ALIGN: usize

The alignment of pointer.

type Init = T

The type for initializers.

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a with the given initializer.

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer.

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer.

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer.

impl<T, U> TryFrom<U> for T
where U: Into<T>,


type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,


type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.

impl<V, T> VZip<V> for T
where V: MultiLane<T>,


fn vzip(self) -> V


impl<T> WithSubscriber for T


fn with_subscriber<S>(self, subscriber: S) -> WithDispatch<Self>
where S: Into<Dispatch>,

Attaches the provided Subscriber to this type, returning a WithDispatch wrapper.

fn with_current_subscriber(self) -> WithDispatch<Self>

Attaches the current default Subscriber to this type, returning a WithDispatch wrapper.