Struct DirichletMultinomial 

pub struct DirichletMultinomial { /* private fields */ }

Dirichlet-Multinomial conjugate prior for categorical data.

Models probability parameters θ₁, …, θₖ for k mutually exclusive categories where Σθᵢ = 1 (probability simplex).

  • Prior: Dirichlet(α₁, …, αₖ)
  • Likelihood: Multinomial(n, θ₁, …, θₖ)
  • Posterior: Dirichlet(α₁ + n₁, …, αₖ + nₖ)

§Mathematical Foundation

Given observations in k categories with counts n₁, …, nₖ:

  • Prior: p(θ) = Dirichlet(α) ∝ ∏θᵢ^(αᵢ-1)
  • Likelihood: p(n|θ) = Multinomial(n|θ) ∝ ∏θᵢ^nᵢ
  • Posterior: p(θ|n) = Dirichlet(α + n)

where α + n means element-wise addition: (α₁ + n₁, …, αₖ + nₖ)
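The element-wise update can be sketched in a few lines of plain Rust. This is an illustration of the conjugacy rule only, not the crate's implementation; the helper name posterior_alphas is hypothetical.

```rust
/// Conjugate update: posterior alphas are prior alphas plus observed counts.
/// (Hypothetical helper for illustration; not part of aprender.)
fn posterior_alphas(prior: &[f32], counts: &[u32]) -> Vec<f32> {
    assert_eq!(prior.len(), counts.len(), "one count per category");
    prior
        .iter()
        .zip(counts)
        .map(|(&alpha, &n)| alpha + n as f32)
        .collect()
}

fn main() {
    // Uniform prior Dirichlet(1, 1, 1) and counts of 10, 5 and 3.
    let prior = [1.0_f32, 1.0, 1.0];
    let posterior = posterior_alphas(&prior, &[10, 5, 3]);
    println!("{:?}", posterior); // [11.0, 6.0, 4.0]
}
```

Because the update is a plain sum, applying the counts in one batch or in several smaller batches yields the same posterior.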

§Example

use aprender::bayesian::DirichletMultinomial;

// 3-category classification: [A, B, C]
let mut model = DirichletMultinomial::uniform(3);

// Observe counts: 10 A's, 5 B's, 3 C's
model.update(&[10, 5, 3]);

// Posterior probabilities
let probs = model.posterior_mean();
assert!((probs[0] - 11.0/21.0).abs() < 0.01); // P(A) = (1+10)/(3+18) ≈ 0.524

Implementations§

impl DirichletMultinomial

pub fn uniform(k: usize) -> DirichletMultinomial

Creates a uniform prior Dirichlet(1, …, 1) for k categories.

This prior is flat over the probability simplex: every category has prior mean 1/k, and the total pseudo-count is only k, so it carries little prior belief.

§Arguments
  • k - Number of categories (must be ≥ 2)
§Panics

Panics if k < 2.

§Example
use aprender::bayesian::DirichletMultinomial;

let prior = DirichletMultinomial::uniform(3);
assert_eq!(prior.alphas().len(), 3);
assert_eq!(prior.alphas()[0], 1.0);

pub fn new(alphas: Vec<f32>) -> Result<DirichletMultinomial, AprenderError>

Creates an informative prior Dirichlet(α₁, …, αₖ) from prior belief.

§Arguments
  • alphas - Concentration parameters αᵢ > 0 for each category
§Interpretation
  • αᵢ: Pseudo-count for category i
  • Σαᵢ: Total pseudo-count (strength of prior belief)
  • αᵢ / Σαⱼ: Prior mean probability for category i
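One way to read the interpretation above is to build the alphas from a desired prior mean and a total pseudo-count. The sketch below assumes that workflow; alphas_from_mean and prior_mean are hypothetical helpers, not functions of this crate.

```rust
/// Scales a prior mean (summing to 1) by a total pseudo-count to get alphas.
/// (Hypothetical helper for illustration.)
fn alphas_from_mean(mean: &[f32], strength: f32) -> Vec<f32> {
    mean.iter().map(|&m| m * strength).collect()
}

/// Recovers the prior mean alpha_i / sum(alpha) from the alphas.
fn prior_mean(alphas: &[f32]) -> Vec<f32> {
    let total: f32 = alphas.iter().sum();
    alphas.iter().map(|&a| a / total).collect()
}

fn main() {
    // Belief [0.5, 0.3, 0.2] held with the weight of 10 observations.
    let alphas = alphas_from_mean(&[0.5, 0.3, 0.2], 10.0);
    println!("alphas = {:?}", alphas);
    println!("mean   = {:?}", prior_mean(&alphas));
}
```

A larger strength keeps the same prior mean but makes the prior harder for data to move.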
§Errors

Returns error if:

  • Any αᵢ ≤ 0
  • Fewer than 2 categories
§Example
use aprender::bayesian::DirichletMultinomial;

// Prior belief: category probabilities [0.5, 0.3, 0.2] with 10 pseudo-counts
let prior = DirichletMultinomial::new(vec![5.0, 3.0, 2.0]).expect("valid concentration parameters");
let mean = prior.posterior_mean();
assert!((mean[0] - 0.5).abs() < 0.01);

pub fn alphas(&self) -> &[f32]

Returns the current concentration parameters.


pub fn num_categories(&self) -> usize

Returns the number of categories.


pub fn update(&mut self, counts: &[u32])

Updates the posterior with observed category counts (Bayesian update).

§Arguments
  • counts - Observed counts for each category
§Panics

Panics if counts.len() != num_categories().

§Example
use aprender::bayesian::DirichletMultinomial;

let mut model = DirichletMultinomial::uniform(3);
model.update(&[10, 5, 3]); // 10 A's, 5 B's, 3 C's

// Posterior is Dirichlet(1+10, 1+5, 1+3) = Dirichlet(11, 6, 4)
assert_eq!(model.alphas()[0], 11.0);

pub fn posterior_mean(&self) -> Vec<f32>

Computes the posterior mean E[θ|data] for all categories.

Returns a vector where element i is E[θᵢ|data] = αᵢ / Σαⱼ.

§Example
use aprender::bayesian::DirichletMultinomial;

let mut model = DirichletMultinomial::uniform(3);
model.update(&[10, 5, 3]);

let mean = model.posterior_mean();
assert!((mean[0] - 11.0/21.0).abs() < 0.01); // (1+10)/(1+1+1+10+5+3)
assert!((mean.iter().sum::<f32>() - 1.0).abs() < 1e-6); // Sums to 1

pub fn posterior_mode(&self) -> Option<Vec<f32>>

Computes the posterior mode (MAP estimate) for all categories.

Returns a vector where element i is (αᵢ - 1) / (Σαⱼ - k). Only defined when all αᵢ > 1.

§Returns
  • Some(mode) if all αᵢ > 1
  • None if any αᵢ ≤ 1 (distribution has no unique mode)
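The mode formula and its guard can be sketched as a standalone function. This mirrors the rule stated above but is not taken from the crate's source; the name dirichlet_mode is hypothetical.

```rust
/// Mode of Dirichlet(alphas): (alpha_i - 1) / (sum(alpha) - k).
/// Returns None when any alpha_i <= 1, where this formula does not apply.
/// (Hypothetical helper for illustration.)
fn dirichlet_mode(alphas: &[f32]) -> Option<Vec<f32>> {
    if alphas.iter().any(|&a| a <= 1.0) {
        return None;
    }
    let k = alphas.len() as f32;
    let total: f32 = alphas.iter().sum();
    Some(alphas.iter().map(|&a| (a - 1.0) / (total - k)).collect())
}

fn main() {
    // Dirichlet(12, 7, 5): mode is (11/21, 6/21, 4/21).
    println!("{:?}", dirichlet_mode(&[12.0, 7.0, 5.0]));
    // An alpha equal to 1 means no interior mode.
    println!("{:?}", dirichlet_mode(&[1.0, 6.0, 4.0]));
}
```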
§Example
use aprender::bayesian::DirichletMultinomial;

let mut model = DirichletMultinomial::new(vec![2.0, 2.0, 2.0]).expect("valid concentration parameters");
model.update(&[10, 5, 3]);

let mode = model.posterior_mode().expect("mode exists when all alphas > 1");
assert!((mode[0] - 11.0/21.0).abs() < 0.01); // (12-1)/(24-3)

pub fn posterior_variance(&self) -> Vec<f32>

Computes the posterior variance Var[θᵢ|data] for all categories.

Returns a vector where element i is: Var[θᵢ] = αᵢ(α₀ - αᵢ) / (α₀²(α₀ + 1)) where α₀ = Σαⱼ
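The variance formula can be checked by hand with a small standalone function. The sketch below only evaluates the expression above; dirichlet_variance is a hypothetical name and not the crate's code.

```rust
/// Marginal variance of each theta_i under Dirichlet(alphas):
/// alpha_i * (a0 - alpha_i) / (a0^2 * (a0 + 1)), with a0 = sum(alpha).
/// (Hypothetical helper for illustration.)
fn dirichlet_variance(alphas: &[f32]) -> Vec<f32> {
    let a0: f32 = alphas.iter().sum();
    alphas
        .iter()
        .map(|&a| a * (a0 - a) / (a0 * a0 * (a0 + 1.0)))
        .collect()
}

fn main() {
    // Posterior after 18 observations on a uniform prior.
    let small = dirichlet_variance(&[11.0, 6.0, 4.0]);
    // Same proportions with ten times the total pseudo-count.
    let large = dirichlet_variance(&[110.0, 60.0, 40.0]);
    println!("small sample: {:?}", small);
    println!("large sample: {:?}", large);
}
```

With the proportions held fixed, the variance shrinks as the total α₀ grows, which is how additional data tightens the posterior.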

§Example
use aprender::bayesian::DirichletMultinomial;

let mut model = DirichletMultinomial::uniform(3);
model.update(&[10, 5, 3]);

let variance = model.posterior_variance();
assert!(variance[0] > 0.0); // Positive uncertainty

pub fn posterior_predictive(&self) -> Vec<f32>

Computes the posterior predictive distribution for the next observation.

For Dirichlet-Multinomial, the posterior predictive probabilities are: P(category i | data) = αᵢ / Σαⱼ

This equals the posterior mean.
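Since the predictive probability is αᵢ / Σαⱼ and each observation adds 1 to its category's α, the probability of a short sequence of future observations follows from the chain rule. The sketch below illustrates that reasoning in plain Rust; predictive and sequence_probability are hypothetical helpers, not part of this crate.

```rust
/// Predictive probability of each category for the next observation.
/// (Hypothetical helper for illustration.)
fn predictive(alphas: &[f32]) -> Vec<f32> {
    let total: f32 = alphas.iter().sum();
    alphas.iter().map(|&a| a / total).collect()
}

/// Probability of observing `sequence` (category indices) in order:
/// multiply the predictive probability of each item, then add 1 to its alpha.
fn sequence_probability(alphas: &[f32], sequence: &[usize]) -> f32 {
    let mut current = alphas.to_vec();
    let mut prob = 1.0_f32;
    for &category in sequence {
        prob *= predictive(&current)[category];
        current[category] += 1.0;
    }
    prob
}

fn main() {
    let posterior = [11.0_f32, 6.0, 4.0];
    // P(A, then A) = (11/21) * (12/22) = 2/7
    println!("{}", sequence_probability(&posterior, &[0, 0]));
}
```

The result does not depend on the order of the sequence: observing A then B has the same probability as B then A.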

§Example
use aprender::bayesian::DirichletMultinomial;

let mut model = DirichletMultinomial::uniform(3);
model.update(&[10, 5, 3]);

let pred = model.posterior_predictive();
assert!((pred[0] - 11.0/21.0).abs() < 0.01);
assert!((pred.iter().sum::<f32>() - 1.0).abs() < 1e-6);

pub fn credible_intervals(&self, confidence: f32) -> Result<Vec<(f32, f32)>, AprenderError>

Computes credible intervals at the given level for all category probabilities.

Returns a vector of (lower, upper) bounds, one per category. Uses a normal approximation to each marginal distribution.

§Arguments
  • confidence - Credible level in (0, 1) (e.g., 0.95 for 95% credible intervals)
§Errors

Returns error if confidence ∉ (0, 1).
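A normal approximation of this kind is commonly written as mean ± z·sd per category. The sketch below is one possible reading and not the crate's implementation: it takes the standard normal quantile z directly (about 1.96 for a 95% level) instead of a confidence level, and the clamping to [0, 1] is an assumption. The name normal_approx_intervals is hypothetical.

```rust
/// Normal-approximation interval for each category: mean ± z * sd.
/// `z` is the standard normal quantile for the desired level (about 1.96 for 95%).
/// Clamping to [0, 1] is an assumption of this sketch.
fn normal_approx_intervals(alphas: &[f32], z: f32) -> Vec<(f32, f32)> {
    let a0: f32 = alphas.iter().sum();
    alphas
        .iter()
        .map(|&a| {
            let mean = a / a0;
            let sd = (a * (a0 - a) / (a0 * a0 * (a0 + 1.0))).sqrt();
            ((mean - z * sd).max(0.0), (mean + z * sd).min(1.0))
        })
        .collect()
}

fn main() {
    let intervals = normal_approx_intervals(&[11.0, 6.0, 4.0], 1.96);
    for (i, (lo, hi)) in intervals.iter().enumerate() {
        println!("category {}: [{:.3}, {:.3}]", i, lo, hi);
    }
}
```

The approximation is symmetric around the mean, so it is least accurate for categories whose probability is close to 0 or 1 or when the total pseudo-count is small.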

§Example
use aprender::bayesian::DirichletMultinomial;

let mut model = DirichletMultinomial::uniform(3);
model.update(&[10, 5, 3]);

let intervals = model.credible_intervals(0.95).expect("valid confidence level");
let mean = model.posterior_mean();

// Mean should be within interval for each category
for i in 0..3 {
    assert!(intervals[i].0 < mean[i] && mean[i] < intervals[i].1);
}

Trait Implementations§

impl Clone for DirichletMultinomial

fn clone(&self) -> DirichletMultinomial

Returns a duplicate of the value.

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source.

impl Debug for DirichletMultinomial

fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error>

Formats the value using the given formatter.
