pub struct DqnAgent<E, Enc, Act, B, Buf = CircularBuffer<<E as Environment>::Observation, <E as Environment>::Action>> { /* private fields */ }
A DQN agent.
Implements ε-greedy action selection, experience replay, and TD learning with a target network. Generic over:

- E: the environment type (must satisfy rl_traits::Environment)
- Enc: the observation encoder (converts E::Observation to tensors)
- Act: the action mapper (converts E::Action to/from integer indices)
- B: the Burn backend (e.g. NdArray, Wgpu)
- Buf: the replay buffer (defaults to CircularBuffer; swap in prioritised experience replay etc.)
§Usage
let agent = DqnAgent::new(encoder, action_mapper, config, device, seed);
Then hand it to DqnRunner, which drives the training loop.
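Expanded into a fuller sketch, with the caveat that MyEnv, MyEncoder, MyMapper, the DqnConfig Default impl, and the DqnRunner constructor shown here are illustrative assumptions, not verified parts of this crate's API:

```rust
use burn::backend::{Autodiff, NdArray};

// Training requires an autodiff backend (B: AutodiffBackend).
type B = Autodiff<NdArray>;

let device = Default::default();
let config = DqnConfig::default(); // assumes a Default impl; tune as needed

// MyEnv, MyEncoder, MyMapper are placeholders for your own environment,
// observation encoder, and discrete action mapper.
let agent: DqnAgent<MyEnv, MyEncoder, MyMapper, B> =
    DqnAgent::new(MyEncoder::new(), MyMapper::new(), config, device, 42);

// Hand the agent to the runner, which drives the training loop.
let runner = DqnRunner::new(MyEnv::new(), agent);
```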
Implementations§
impl<E, Enc, Act, B> DqnAgent<E, Enc, Act, B>
where
    E: Environment,
    E::Observation: Clone + Send + Sync + 'static,
    E::Action: Clone + Send + Sync + 'static,
    Enc: ObservationEncoder<E::Observation, B> + ObservationEncoder<E::Observation, B::InnerBackend>,
    Act: DiscreteActionMapper<E::Action>,
    B: AutodiffBackend,
impl<E, Enc, Act, B, Buf> DqnAgent<E, Enc, Act, B, Buf>
where
    E: Environment,
    E::Observation: Clone + Send + Sync + 'static,
    E::Action: Clone + Send + Sync + 'static,
    Enc: ObservationEncoder<E::Observation, B> + ObservationEncoder<E::Observation, B::InnerBackend>,
    Act: DiscreteActionMapper<E::Action>,
    B: AutodiffBackend,
    Buf: ReplayBuffer<E::Observation, E::Action>,
pub fn new_with_buffer(
encoder: Enc,
action_mapper: Act,
config: DqnConfig,
device: B::Device,
seed: u64,
buffer: Buf,
) -> Self
Create a new agent with a custom replay buffer.
Use this to swap in prioritised experience replay or any other
ReplayBuffer implementation in place of the default CircularBuffer.
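For example, with a hypothetical PrioritizedBuffer type (illustrative only; any ReplayBuffer implementation works):

```rust
// PrioritizedBuffer and its constructor are assumptions for this sketch.
let buffer = PrioritizedBuffer::with_capacity(100_000);
let agent = DqnAgent::new_with_buffer(encoder, action_mapper, config, device, 42, buffer);
```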
pub fn observe(
&mut self,
experience: Experience<E::Observation, E::Action>,
) -> bool
Store a transition in the replay buffer and potentially run a gradient update.
Called by the runner after every environment step. Returns true if
a gradient update was performed this step.
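If you drive the loop yourself rather than through DqnRunner, each step follows roughly this pattern. The env.step signature and the Experience field names below are assumptions for illustration, not verified API:

```rust
let action = agent.act_epsilon_greedy(&obs, &mut rng);
let (next_obs, reward, done) = env.step(action.clone()); // hypothetical step API
let did_update = agent.observe(Experience {
    observation: obs,
    action,
    reward,
    next_observation: next_obs,
    done,
}); // true if a gradient update ran this step
```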
pub fn total_steps(&self) -> usize
Total environment steps observed so far.
pub fn act_epsilon_greedy(
&self,
obs: &E::Observation,
rng: &mut impl Rng,
) -> E::Action
Select an action using ε-greedy policy.
pub fn save(&self, path: impl AsRef<Path>) -> Result<(), RecorderError>
Save the online network weights to a file.
Uses Burn’s CompactRecorder (MessagePack format). The recorder appends
its own extension to the path, so save("run/cartpole") produces
run/cartpole.mpk.
Only the online network weights are saved — the target network,
replay buffer, and optimizer state are not included. This is sufficient
for inference. To resume training, call load followed by
set_total_steps to restore the correct epsilon.
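A save/resume round trip might look like this (the step count passed to set_total_steps is whatever your training loop recorded at checkpoint time; persisting and reading it back is up to you):

```rust
// After training: writes run/cartpole.mpk (the recorder appends the extension).
agent.save("run/cartpole")?;

// Later: rebuild the agent, load the weights, and restore the step counter
// so epsilon resumes from where the checkpoint left off.
let mut agent = DqnAgent::new(encoder, action_mapper, config, device, seed)
    .load("run/cartpole")?;
agent.set_total_steps(50_000); // steps observed before the checkpoint
```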
pub fn load(self, path: impl AsRef<Path>) -> Result<Self, RecorderError>
Load network weights from a file into this agent.
Loads into the online network and immediately syncs the target network.
Takes self by value and returns the updated agent so you can chain
with the constructor:
let agent = DqnAgent::new(...).load("run/cartpole")?;
pub fn into_policy(self) -> DqnPolicy<E, Enc, Act, B::InnerBackend>
Convert this trained agent into an inference-only DqnPolicy.
Strips all training state (optimizer, buffer, RNG) and downcasts the
network to B::InnerBackend (no autodiff). Use this when training is
complete and you want a lightweight policy for evaluation or deployment.
let policy = runner.into_agent().into_policy();
let action = policy.act(&obs);
pub fn set_total_steps(&mut self, steps: usize)
Override the internal step counter.
Useful when resuming training: restores epsilon to the correct value for the point in training where the checkpoint was saved.
Trait Implementations§
impl<E, Enc, Act, B, Buf> Policy<<E as Environment>::Observation, <E as Environment>::Action> for DqnAgent<E, Enc, Act, B, Buf>
where
    E: Environment,
    Enc: ObservationEncoder<E::Observation, B>,
    Act: DiscreteActionMapper<E::Action>,
    B: AutodiffBackend,
    Buf: ReplayBuffer<E::Observation, E::Action>,
fn act(&self, obs: &E::Observation) -> E::Action
Greedy action selection (no exploration).
Use this for evaluation. For training, use act_epsilon_greedy.
Auto Trait Implementations§
impl<E, Enc, Act, B, Buf> Freeze for DqnAgent<E, Enc, Act, B, Buf>
impl<E, Enc, Act, B, Buf = CircularBuffer<<E as Environment>::Observation, <E as Environment>::Action>> !RefUnwindSafe for DqnAgent<E, Enc, Act, B, Buf>
impl<E, Enc, Act, B, Buf> Send for DqnAgent<E, Enc, Act, B, Buf>
impl<E, Enc, Act, B, Buf = CircularBuffer<<E as Environment>::Observation, <E as Environment>::Action>> !Sync for DqnAgent<E, Enc, Act, B, Buf>
impl<E, Enc, Act, B, Buf> Unpin for DqnAgent<E, Enc, Act, B, Buf>
where
Buf: Unpin,
Enc: Unpin,
Act: Unpin,
<B as Backend>::Device: Unpin,
E: Unpin,
<B as Backend>::FloatTensorPrimitive: Unpin,
<B as Backend>::QuantizedTensorPrimitive: Unpin,
<<B as AutodiffBackend>::InnerBackend as Backend>::FloatTensorPrimitive: Unpin,
<<B as AutodiffBackend>::InnerBackend as Backend>::QuantizedTensorPrimitive: Unpin,
impl<E, Enc, Act, B, Buf> UnsafeUnpin for DqnAgent<E, Enc, Act, B, Buf>
impl<E, Enc, Act, B, Buf = CircularBuffer<<E as Environment>::Observation, <E as Environment>::Action>> !UnwindSafe for DqnAgent<E, Enc, Act, B, Buf>