block_graph/
ai.rs

1decompose_primitive!((),bool,char,f32,f64,i128,i16,i32,i64,i8,isize,u128,u16,u32,u64,u8,usize);
2decompose_tuple!((A,B),(A,B,C),(A,B,C,D),(A,B,C,D,E),(A,B,C,D,E,F),(A,B,C,D,E,F,G),(A,B,C,D,E,F,G,H));
3impl Decompose for Range<usize>{
4	fn compose(decomposition:Self::Decomposition)->Self{decomposition.0..decomposition.1}
5	fn decompose(self)->Self::Decomposition{(self.start,self.end)}
6	fn decompose_cloned(&self)->Self::Decomposition{(self.start,self.end)}
7	type Decomposition=(usize,usize);
8}
9impl Op for (){
10	type Output=();
11}
12impl<A:?Sized+AI<X,Y>,X,Y> AI<X,Y> for &A{
13	fn forward(&self,input:X)->Y{(**self).forward(input)}
14}
15impl<A:?Sized+AI<X,Y>,X,Y> AI<X,Y> for &mut A{
16	fn forward(&self,input:X)->Y{(**self).forward(input)}
17	fn forward_mut(&mut self,input:X)->Y{(**self).forward_mut(input)}
18}
19impl<A:?Sized+Op<Output=Y>,Y> Op for &A{
20	type Output=Y;
21}
22impl<A:?Sized+Op<Output=Y>,Y> Op for &mut A{
23	type Output=Y;
24}
25impl<A:AI<X,X>+Op<Output=X>,X> Op for Option<A>{
26	type Output=X;
27}
28impl<A:AI<X,X>,X> AI<X,X> for Option<A>{
29	fn forward(&self,x:X)->X{
30		if let Some(a)=self{a.forward(x)}else{x}
31	}
32	fn forward_mut(&mut self,x:X)->X{
33		if let Some(a)=self{a.forward_mut(x)}else{x}
34	}
35}
36impl<A:AI<X,Y>,X,Y> AI<X,Y> for Inner<A>{
37	fn forward(&self,input:X)->Y{self.0.forward(input)}
38	fn forward_mut(&mut self,input:X)->Y{self.0.forward_mut(input)}
39}
40impl<A:Decompose> Decompose for Inner<A>{
41	fn compose(decomposition:Self::Decomposition)->Self{Self(A::compose(decomposition))}
42	fn decompose(self)->Self::Decomposition{self.0.decompose()}
43	fn decompose_cloned(&self)->Self::Decomposition{self.0.decompose_cloned()}
44	type Decomposition=A::Decomposition;
45}
46impl<A:Decompose> Decompose for Option<A>{
47	fn compose(decomposition:Self::Decomposition)->Self{decomposition.map(A::compose)}
48	fn decompose(self)->Self::Decomposition{self.map(A::decompose)}
49	fn decompose_cloned(&self)->Self::Decomposition{self.as_ref().map(A::decompose_cloned)}
50	type Decomposition=Option<A::Decomposition>;
51}
52impl<A:Decompose> Decompose for Vec<A>{
53	fn compose(decomposition:Self::Decomposition)->Self{decomposition.into_iter().map(A::compose).collect()}
54	fn decompose(self)->Self::Decomposition{self.into_iter().map(A::decompose).collect()}
55	fn decompose_cloned(&self)->Self::Decomposition{self.iter().map(A::decompose_cloned).collect()}
56	type Decomposition=Vec<A::Decomposition>;
57}
58impl<A:IntoSequence<M>,M:AI<M::Output,M::Output>+Op> IntoSequence<M> for Inner<A>{
59	fn into_sequence(self)->Sequential<Vec<M>>{self.0.into_sequence()}
60}
61impl<A:Op> Op for Inner<A>{
62	type Output=A::Output;
63}
64impl<A> From<A> for Inner<A>{
65	fn from(inner:A)->Self{Self(inner)}
66}
67impl<A> Inner<A>{
68	/// references the inner value
69	pub fn inner(&self)->&A{&self.0}
70	/// references the inner value
71	pub fn inner_mut(&mut self)->&mut A{&mut self.0}
72	/// converts into the inner value
73	pub fn into_inner(self)->A{self.0}
74}
75impl<A> Op for [A]{
76	type Output=();
77}
78impl<A> Op for Vec<A>{
79	type Output=();
80}
81impl<A> UnwrapInner for Inner<A>{
82	fn unwrap_inner(self)->Self::Inner{self.0}
83	type Inner=A;
84}
85impl<K:Decompose+Eq+Hash,V:Decompose,S:Default+BuildHasher> Decompose for HashMap<K,V,S> where K::Decomposition:Ord{
86	fn compose(decomposition:Self::Decomposition)->Self{decomposition.into_iter().map(Decompose::compose).collect()}
87	fn decompose(self)->Self::Decomposition{
88		let mut v:Vec<_>=self.into_iter().map(Decompose::decompose).collect();
89		v.sort_unstable_by(|(k,_v),(k2,_v2)|k.cmp(k2));
90		v
91	}
92	fn decompose_cloned(&self)->Self::Decomposition{
93		let mut v:Vec<_>=self.iter().map(|(k,v)|(k.decompose_cloned(),v.decompose_cloned())).collect();
94		v.sort_unstable_by(|(k,_v),(k2,_v2)|k.cmp(k2));
95		v
96	}
97	type Decomposition=Vec<(K::Decomposition,V::Decomposition)>;
98}
99impl<X:Into<Y>,Y> AI<X,Y> for (){
100	fn forward(&self,input:X)->Y{input.into()}
101}
/// implements decompose for primitive types: each listed type is its own decomposition, so compose and decompose are identity conversions and the cloned decomposition is a clone of the value
macro_rules! decompose_primitive{
	($($leaf:ty),*)=>{
		$(
			impl Decompose for $leaf{
				type Decomposition=Self;
				fn compose(decomposition:Self::Decomposition)->Self{decomposition}
				fn decompose(self)->Self::Decomposition{self}
				fn decompose_cloned(&self)->Self::Decomposition{Clone::clone(self)}
			}
		)*
	};
}
/// implements decompose for tuples: every element decomposes independently and the decomposition is the tuple of the element decompositions
macro_rules! decompose_tuple{
	($(($($T:ident),+)),*)=>{
		$(
			// the element bindings reuse the uppercase type parameter names
			#[allow(non_snake_case)]
			impl<$($T:Decompose),+> Decompose for ($($T),+){
				type Decomposition=($($T::Decomposition),+);
				fn compose(($($T),+):Self::Decomposition)->Self{
					($(Decompose::compose($T)),+)
				}
				fn decompose(self)->Self::Decomposition{
					let ($($T),+)=self;
					($($T.decompose()),+)
				}
				fn decompose_cloned(&self)->Self::Decomposition{
					let ($($T),+)=self;
					($($T.decompose_cloned()),+)
				}
			}
		)*
	};
}
/// implements op for tuples whose elements are all operations, with unit as the suggested output
macro_rules! op_tuple{
	($(($($T:ident),+)),*)=>{
		$(
			impl<$($T:Op),+> Op for ($($T),+){
				type Output=();
			}
		)*
	};
}
134op_tuple!((A,B),(A,B,C),(A,B,C,D),(A,B,C,D,E),(A,B,C,D,E,F),(A,B,C,D,E,F,G),(A,B,C,D,E,F,G,H));
/// wraps an inner value so it can be unwrapped with unwrap inner.
/// The layout is transparent, so the wrapper has the same representation as the wrapped value.
#[repr(transparent)]
#[derive(Clone,Copy,Debug,Default,Eq,Hash,Ord,PartialEq,PartialOrd)]
pub struct Inner<A>(pub A);
/// general ai trait: a module that maps an input of type `X` to an output of type `Y`
pub trait AI<X,Y>{
	/// applies to the input
	fn forward(&self,input:X)->Y;
	/// applies to the input, possibly updating internal caches.
	/// Unless overridden this simply calls [`forward`](AI::forward).
	fn forward_mut(&mut self,input:X)->Y{
		<Self as AI<X,Y>>::forward(self,input)
	}
}
/// trait to decompose AI modules into components that implement other libraries' traits
// TODO derive macros, make decompose cloned and decompose take and into sequence cloned and into sequence take
pub trait Decompose{
	/// the decomposed type
	type Decomposition;
	/// recreates a value from its decomposition
	fn compose(decomposition:Self::Decomposition)->Self where Self:Sized;
	/// consumes the value, producing its owned decomposition
	fn decompose(self)->Self::Decomposition where Self:Sized;
	/// produces a decomposition from a borrow, copying data and leaving the value intact
	fn decompose_cloned(&self)->Self::Decomposition;
}
157/// conversion from a composite module into a sequential list of modules
158pub trait IntoSequence<M:AI<M::Output,M::Output>+Op>{
159	/// converts into a sequential module list
160	fn into_sequence(self)->Sequential<Vec<M>>;
161}
162/// composition trait
163pub trait Op{
164	/// wraps with a softmax operation
165	fn abnormal_softmax(self,temperature:f32)->AbnormalSoftmax<Self> where Self:Sized,AbnormalSoftmax<Self>:Op{AbnormalSoftmax::new(self,temperature)}
166	/// wraps with an absolute value operation
167	fn abs(self)->Abs<Self> where Self:Sized,Abs<Self>:Op{Abs::new(self)}
168	/// wraps with a accq operation
169	fn acc_q(self,gamma:f32)->AccQ<Self> where AccQ<Self>:Op,Self:Sized{AccQ::new(gamma,self)}
170	/// wraps with a cat operation
171	fn cat(self,dim:i32)->Cat<Self> where Cat<Self>:Op,Self:Sized{Cat::new(dim,self)}
172	/// sequences with another ai operation
173	fn chain<B>(self,b:B)->Sequential<(Self,B)> where Self:Sized,Sequential<(Self,B)>:Op{Sequential::new((self,b))}
174	/// wraps with a cross entropy operation. If temperature is a number it will be used to apply softmax to the logits before computing entropy with the target. if the input will already be a probability distribution instead of logits, put NaN temperature
175	fn cross_entropy(self,temperature:f32)->CrossEntropy<Self> where CrossEntropy<Self>:Op,Self:Sized{CrossEntropy::new(self,temperature)}
176	/// wraps with a duplicate operation
177	fn duplicate(self)->Duplicate<Self> where Duplicate<Self>:Op,Self:Sized{Duplicate::new(self)}
178	/// set type but with the same input and output
179	fn fix_type<Z>(self)->SetType<Self,Z,Z> where Self:AI<Z,Z>+Sized{self.set_type()}
180	/// applies to the input
181	fn forward_fixed<Z>(&self,input:Z)->Z where Self:AI<Z,Z>+Sized{self.forward(input)}
182	/// applies to the input
183	fn forward_fixed_mut<Z>(&mut self,input:Z)->Z where Self:AI<Z,Z>+Sized{self.forward(input)}
184	/// applies to the input
185	fn forward_typed<W,Z>(&self,input:W)->Z where Self:AI<W,Z>+Sized{self.forward(input)}
186	/// applies to the input, possibly updating internal caches
187	fn forward_typed_mut<W,Z>(&mut self,input:W)->Z where Self:AI<W,Z>+Sized{self.forward(input)}
188	/// creates an autoregressive inference
189	fn infer_autoregressive<X,Y>(self,input:X)->Autoregression<Self,Y> where Self:AI<X,Y>+AI<Y,Y>+Sized,Y:Clone{Autoregression::new(self,input)}
190	/// wraps with a softmax operation
191	fn log_softmax(self,temperature:f32)->LogSoftmax<Self> where Self:Sized,LogSoftmax<Self>:Op{LogSoftmax::new(self,temperature)}
192	/// applies the operation to every output
193	fn map<B>(self,b:B)->Map<Sequential<(Self,B)>> where Map<Sequential<(Self,B)>>:Op,Self:Sized,Sequential<(Self,B)>:Op{self.chain(b).to_each()}
194	/// wraps with a mean operation
195	fn mean(self)->Mean<Self> where Mean<Self>:Op,Self:Sized{Mean::new(self)}
196	/// creates an optional operation
197	fn optional(self)->Option<Self> where Self:Sized{Some(self)}
198	/// produces a zip module
199	fn separately(self)->Zip<Self> where Self:Sized,Zip<Self>:Op{Zip::new(self)}
200	/// produces a sequential module
201	fn sequential(self)->Sequential<Self> where Self:Sized,Sequential<Self>:Op{Sequential::new(self)}
202	/// sets the input output types
203	fn set_type<W,Z>(self)->SetType<Self,W,Z> where Self:AI<W,Z>+Sized{SetType::new(self)}
204	/// wraps with a choose operation
205	fn soft_choose(self,temperature:f32)->Choose<Self> where Self:Sized,Choose<Self>:Op{Choose::new(self,temperature)}
206	/// wraps with a softmax operation
207	fn softmax(self,temperature:f32)->Softmax<Self> where Self:Sized,Softmax<Self>:Op{Softmax::new(self,temperature)}
208	/// wraps with a mse operation
209	fn squared_error(self)->SquaredError<Self> where SquaredError<Self>:Op,Self:Sized{SquaredError::new(self)}
210	/// wraps with a squeeze operation
211	fn squeeze(self,dim:i32)->Squeeze<Self> where Squeeze<Self>:Op,Self:Sized{Squeeze::new(dim,self)}
212	/// wraps with a stack operation
213	fn stack(self,dim:i32)->Stack<Self> where Stack<Self>:Op,Self:Sized{Stack::new(dim,self)}
214	/// wraps with a map operation
215	fn to_each(self)->Map<Self> where Map<Self>:Op,Self:Sized{Map::new(self)}
216	/// wraps with a unsqueeze operation
217	fn unsqueeze(self,dim:i32)->Unsqueeze<Self> where Unsqueeze<Self>:Op,Self:Sized{Unsqueeze::new(dim,self)}
218	/// wraps with a sum operation
219	fn sum(self)->Sum<Self> where Sum<Self>:Op,Self:Sized{Sum::new(self)}
220	/// wraps the inner value so it can be unwrapped with unwrap inner
221	fn wrap_inner(self)->Inner<Self> where Self:Sized{Inner(self)}
222	/// zips with another ai operation
223	fn zip<B>(self,b:B)->Zip<(Self,B)> where Self:Sized,Zip<(Self,B)>:Op{Zip::new((self,b))}
224	/// suggested output type to help with composition coherence. Ideally, Self should implement AI<X,Self::Output> for some X
225	type Output;
226}
/// trait for unwrapping nested wrapped values, one layer per call
pub trait UnwrapInner{
	/// the type exposed once a layer of wrapping is removed
	type Inner;
	/// consumes self, returning the wrapped value
	fn unwrap_inner(self)->Self::Inner;
}
234use {op_tuple,decompose_primitive,decompose_tuple};
235use crate::builtin::{
236	Autoregression,Duplicate,Map,Sequential,SetType,Zip,math::{Abs,Mean,SquaredError,Sum},reinforcement::AccQ,soft::{AbnormalSoftmax,Choose,CrossEntropy,LogSoftmax,Softmax},structural::{Cat,Squeeze,Stack,Unsqueeze}
237};
238use std::{
239	collections::HashMap,cmp::Ord,hash::{BuildHasher,Hash},ops::Range
240};