use std::fmt;
use std::fmt::Debug;
use std::fmt::Display;
use std::sync::Arc;
use allocative::Allocative;
use dupe::Dupe;
use dupe::IterDupedExt;
use either::Either;
use pagable::PagableRegisteredFor;
use starlark_derive::Trace;
use starlark_syntax::codemap::CodeMap;
use starlark_syntax::codemap::Span;
use starlark_syntax::codemap::Spanned;
use crate as starlark;
use crate::eval::compiler::small_vec_1::SmallVec1;
use crate::typing::ParamSpec;
use crate::typing::TypingOracleCtx;
use crate::typing::arc_ty::ArcTy;
use crate::typing::basic::TyBasic;
use crate::typing::call_args::TyCallArgs;
use crate::typing::callable::TyCallable;
use crate::typing::custom::TyCustom;
use crate::typing::custom::TyCustomDyn;
use crate::typing::custom::TyCustomImpl;
use crate::typing::error::TypingNoContextError;
use crate::typing::function::TyCustomFunction;
use crate::typing::function::TyCustomFunctionImpl;
use crate::typing::function::TyFunction;
use crate::typing::small_arc_vec::SmallArcVec1;
use crate::typing::starlark_value::TyStarlarkValue;
use crate::typing::structs::TyStruct;
use crate::typing::tuple::TyTuple;
use crate::values::StarlarkValue;
use crate::values::Value;
use crate::values::bool::StarlarkBool;
use crate::values::typing::never::TypingNever;
#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct Approximation {
pub category: &'static str,
pub message: String,
}
impl Approximation {
pub fn new(category: &'static str, message: impl Debug) -> Self {
Self {
category,
message: format!("{message:?}"),
}
}
}
impl Display for Approximation {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Approximation: {} = {:?}", self.category, self.message)
}
}
#[derive(
Debug, Clone, Dupe, PartialEq, Eq, Hash, PartialOrd, Ord, Allocative, Trace
)]
#[derive(pagable::Pagable)]
pub struct Ty {
alternatives: SmallArcVec1<TyBasic>,
}
fn merge_adjacent<T>(xs: Vec<T>, f: impl Fn(T, T) -> Either<T, (T, T)>) -> SmallVec1<T> {
let mut res = SmallVec1::new();
let mut last = None;
for x in xs {
match last {
None => last = Some(x),
Some(l) => match f(l, x) {
Either::Left(x) => last = Some(x),
Either::Right((l, x)) => {
res.push(l);
last = Some(x)
}
},
}
}
if let Some(l) = last {
res.push(l)
}
res
}
impl Ty {
pub fn todo() -> Self {
Ty::any()
}
pub fn as_name(&self) -> Option<&str> {
match self.alternatives.as_slice() {
[x] => x.as_name(),
_ => None,
}
}
pub(crate) fn is_starlark_value(&self) -> Option<TyStarlarkValue> {
match self.iter_union() {
[TyBasic::StarlarkValue(x)] => Some(*x),
_ => None,
}
}
pub(crate) const fn basic(basic: TyBasic) -> Self {
Ty {
alternatives: SmallArcVec1::one(basic),
}
}
pub const fn any() -> Self {
Ty::basic(TyBasic::Any)
}
pub const fn never() -> Self {
Ty {
alternatives: SmallArcVec1::empty(),
}
}
pub const fn none() -> Self {
Ty::basic(TyBasic::none())
}
pub const fn bool() -> Self {
Ty::starlark_value::<StarlarkBool>()
}
pub const fn int() -> Self {
Ty::basic(TyBasic::int())
}
pub const fn float() -> Self {
Ty::basic(TyBasic::float())
}
pub const fn string() -> Self {
Ty::basic(TyBasic::string())
}
pub fn list(element: Ty) -> Self {
Ty::basic(TyBasic::list(element))
}
pub(crate) fn any_list() -> Self {
Self::list(Ty::any())
}
pub fn iter(item: Ty) -> Self {
Ty::basic(TyBasic::iter(item))
}
pub fn dict(key: Ty, value: Ty) -> Self {
Ty::basic(TyBasic::dict(key, value))
}
pub(crate) fn any_dict() -> Self {
Self::dict(Ty::any(), Ty::any())
}
pub fn set(item: Ty) -> Self {
Ty::basic(TyBasic::set(item))
}
pub(crate) fn any_set() -> Self {
Self::set(Ty::any())
}
pub fn tuple2(a: Ty, b: Ty) -> Self {
Ty::tuple(vec![a, b])
}
pub fn tuple(elems: Vec<Ty>) -> Self {
Ty::basic(TyBasic::Tuple(TyTuple::Elems(elems.into())))
}
pub(crate) fn any_tuple() -> Self {
Self::tuple_of(Ty::any())
}
pub(crate) fn tuple_of(item: Ty) -> Self {
Ty::basic(TyBasic::Tuple(TyTuple::Of(ArcTy::new(item))))
}
pub fn function(params: ParamSpec, result: Ty) -> Self {
Self::ty_function(TyFunction::new(params, result))
}
pub fn callable(params: ParamSpec, result: Ty) -> Self {
Ty::basic(TyBasic::Callable(TyCallable::new(params, result)))
}
pub fn ty_function(f: TyFunction) -> Self {
Self::custom(TyCustomFunction(f))
}
pub fn ctor_function(type_attr: Ty, params: ParamSpec, result: Ty) -> Self {
Self::custom(TyCustomFunction(TyFunction::new_with_type_attr(
params, result, type_attr,
)))
}
pub(crate) fn any_callable() -> Self {
Ty::basic(TyBasic::Callable(TyCallable::any()))
}
pub(crate) fn any_struct() -> Self {
Ty::custom(TyStruct::any())
}
pub const fn starlark_value<'v, T: StarlarkValue<'v>>() -> Self {
Ty::basic(TyBasic::starlark_value::<T>())
}
pub(crate) fn is_any(&self) -> bool {
self == &Ty::any()
}
pub(crate) fn is_never(&self) -> bool {
self.alternatives.is_empty()
}
pub(crate) fn is_list(&self) -> bool {
matches!(self.alternatives.as_slice(), [TyBasic::List(_)])
}
pub(crate) fn is_function(&self) -> bool {
self.as_name() == Some("function")
}
pub fn as_function(&self) -> Option<&TyFunction> {
match self.iter_union() {
[x] => x.as_function(),
_ => None,
}
}
pub fn unions(xs: Vec<Self>) -> Self {
if xs.iter().any(|x| x.is_any()) {
return Ty::any();
}
fn next_skip_never<I: Iterator<Item = Ty>>(iter: &mut I) -> Option<Ty> {
iter.find(|x| !x.is_never())
}
let mut xs = xs.into_iter();
let Some(x0) = next_skip_never(&mut xs) else {
return Ty::never();
};
let Some(x1) = next_skip_never(&mut xs) else {
return x0;
};
if xs.len() == 0 && x0 == x1 {
return x0;
}
let xs = xs.as_slice();
let mut xs: Vec<TyBasic> = [x0, x1]
.iter()
.chain(xs)
.flat_map(|x| x.iter_union())
.duped()
.collect();
xs.sort();
xs.dedup();
let xs = merge_adjacent(xs, |x, y| match (x, y) {
(TyBasic::List(x), TyBasic::List(y)) => {
Either::Left(TyBasic::List(ArcTy::union2(x, y)))
}
(TyBasic::Dict(x_k, x_v), TyBasic::Dict(y_k, y_v)) => Either::Left(TyBasic::Dict(
ArcTy::union2(x_k, y_k),
ArcTy::union2(x_v, y_v),
)),
(TyBasic::Custom(x), TyBasic::Custom(y)) => match TyCustom::union2(x, y) {
Ok(u) => Either::Left(TyBasic::Custom(u)),
Err((x, y)) => Either::Right((TyBasic::Custom(x), TyBasic::Custom(y))),
},
xy => Either::Right(xy),
});
Ty {
alternatives: xs.into_iter().collect(),
}
}
pub fn iter_union(&self) -> &[TyBasic] {
&self.alternatives
}
pub(crate) fn typecheck_union_simple(
&self,
typecheck: impl Fn(&TyBasic) -> Result<Ty, TypingNoContextError>,
) -> Result<Ty, TypingNoContextError> {
if self.is_any() || self.is_never() {
Ok(self.dupe())
} else {
match self.iter_union() {
[x] => typecheck(x),
xs => {
let mut good = Vec::with_capacity(self.alternatives.len());
for basic in xs {
match typecheck(basic) {
Ok(ty) => good.push(ty),
Err(TypingNoContextError) => {}
}
}
if good.is_empty() {
Err(TypingNoContextError)
} else {
Ok(Ty::unions(good))
}
}
}
}
}
pub fn union2(a: Self, b: Self) -> Self {
if a.is_any() || b.is_any() {
Ty::any()
} else if a == b {
a
} else if a.is_never() {
b
} else if b.is_never() {
a
} else {
Ty::unions(vec![a, b])
}
}
pub fn custom(t: impl TyCustomImpl) -> Self {
Ty::basic(TyBasic::Custom(TyCustom(Arc::new(t))))
}
#[expect(
private_bounds,
reason = "TyCustomDyn is intentionally pub(crate); it appears in the bound \
only to enforce typetag registration at compile time and is not \
part of the public API surface callers need to name."
)]
pub fn custom_function<F>(f: F) -> Self
where
F: TyCustomFunctionImpl,
TyCustomFunction<F>: PagableRegisteredFor<dyn TyCustomDyn>,
{
Ty::custom(TyCustomFunction(f))
}
pub fn of_value(value: Value) -> Ty {
match value.get_ref().typechecker_ty() {
Some(t) => t,
_ => value.get_type_starlark_repr(),
}
}
#[must_use]
pub(crate) fn check_call<'a>(
&self,
pos: impl IntoIterator<Item = Ty>,
named: impl IntoIterator<Item = (&'a str, Ty)>,
args: Option<Ty>,
kwargs: Option<Ty>,
expected_return_type: Ty,
) -> bool {
let oracle = TypingOracleCtx {
codemap: CodeMap::empty_static(),
};
let Ok(ret) = oracle.validate_call(
Span::default(),
self,
&TyCallArgs {
pos: pos
.into_iter()
.map(|p| Spanned {
span: Span::default(),
node: p,
})
.collect(),
named: named
.into_iter()
.map(|p| Spanned {
span: Span::default(),
node: p,
})
.collect(),
args: args.map(|t| Spanned {
span: Span::default(),
node: t,
}),
kwargs: kwargs.map(|t| Spanned {
span: Span::default(),
node: t,
}),
},
) else {
return false;
};
let Ok(ok) = oracle.intersects(&ret, &expected_return_type) else {
return false;
};
ok
}
pub(crate) fn check_intersects(&self, other: &Ty) -> crate::Result<bool> {
let oracle = TypingOracleCtx {
codemap: CodeMap::empty_static(),
};
match oracle.intersects(self, other) {
Ok(ok) => Ok(ok),
Err(e) => Err(e.into_error()),
}
}
pub(crate) fn fmt_with_config(
&self,
f: &mut fmt::Formatter<'_>,
config: &TypeRenderConfig,
) -> fmt::Result {
match self.iter_union() {
[] => write!(f, "{}", TypingNever::TYPE),
xs => {
for (i, x) in xs.iter().enumerate() {
if i != 0 {
write!(f, " | ")?;
}
x.fmt_with_config(f, config)?;
}
Ok(())
}
}
}
pub(crate) fn display_with<'a>(&'a self, config: &'a TypeRenderConfig) -> TyDisplay<'a> {
TyDisplay { ty: self, config }
}
}
pub enum TypeRenderConfig {
Default,
LinkedType {
render_linked_ty_starlark_value: Box<dyn Fn(&TyStarlarkValue) -> String>,
},
}
pub(crate) struct TyDisplay<'a> {
ty: &'a Ty,
config: &'a TypeRenderConfig,
}
impl Display for TyDisplay<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.ty.fmt_with_config(f, self.config)
}
}
impl Display for Ty {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.fmt_with_config(f, &TypeRenderConfig::Default)
}
}
#[cfg(all(test, feature = "pagable"))]
mod tests {
use pagable::PagableDeserialize;
use pagable::PagableSerialize;
use pagable::testing::TestingDeserializer;
use pagable::testing::TestingSerializer;
use super::*;
use crate::typing::ParamSpec;
use crate::typing::function::TyFunction;
use crate::values::dict::value::FrozenDict;
fn round_trip(ty: &Ty) -> Ty {
let mut ser = TestingSerializer::new();
ty.pagable_serialize(&mut ser).unwrap();
let bytes = ser.finish();
let mut de = TestingDeserializer::new(&bytes);
Ty::pagable_deserialize(&mut de).unwrap()
}
#[track_caller]
fn assert_round_trip(ty: Ty) {
assert_eq!(ty, round_trip(&ty));
}
#[test]
fn test_round_trip_primitives() {
assert_round_trip(Ty::any());
assert_round_trip(Ty::never());
assert_round_trip(Ty::none());
assert_round_trip(Ty::bool());
assert_round_trip(Ty::int());
assert_round_trip(Ty::float());
assert_round_trip(Ty::string());
}
#[test]
fn test_round_trip_containers() {
assert_round_trip(Ty::list(Ty::string()));
assert_round_trip(Ty::dict(Ty::string(), Ty::int()));
assert_round_trip(Ty::set(Ty::int()));
assert_round_trip(Ty::tuple(vec![Ty::int(), Ty::string()]));
assert_round_trip(Ty::iter(Ty::int()));
}
#[test]
fn test_round_trip_callable() {
assert_round_trip(Ty::callable(ParamSpec::empty(), Ty::int()));
}
#[test]
fn test_round_trip_ty_function() {
assert_round_trip(Ty::ty_function(TyFunction::new(
ParamSpec::empty(),
Ty::string(),
)));
}
#[test]
fn test_ty_custom_arc_dedup_roundtrip() {
let ty1 = Ty::ty_function(TyFunction::new(ParamSpec::empty(), Ty::string()));
let ty2 = ty1.dupe();
let mut ser = TestingSerializer::new();
ty1.pagable_serialize(&mut ser).unwrap();
ty2.pagable_serialize(&mut ser).unwrap();
let shared_bytes = ser.finish();
let ty_a = Ty::ty_function(TyFunction::new(ParamSpec::empty(), Ty::string()));
let ty_b = Ty::ty_function(TyFunction::new(ParamSpec::empty(), Ty::string()));
let mut ser = TestingSerializer::new();
ty_a.pagable_serialize(&mut ser).unwrap();
ty_b.pagable_serialize(&mut ser).unwrap();
let distinct_bytes = ser.finish();
assert!(
shared_bytes.len() < distinct_bytes.len(),
"dedup should make shared-Arc encoding smaller: shared={} distinct={}",
shared_bytes.len(),
distinct_bytes.len(),
);
let mut de = TestingDeserializer::new(&shared_bytes);
let restored1 = Ty::pagable_deserialize(&mut de).unwrap();
let restored2 = Ty::pagable_deserialize(&mut de).unwrap();
assert_eq!(ty1, restored1);
assert_eq!(ty2, restored2);
match (
&restored1.alternatives.as_slice()[0],
&restored2.alternatives.as_slice()[0],
) {
(TyBasic::Custom(a), TyBasic::Custom(b)) => {
assert!(
std::sync::Arc::ptr_eq(&a.0, &b.0),
"deduped Arcs should round-trip to the same allocation",
);
}
_ => panic!("expected TyBasic::Custom"),
}
}
#[test]
fn test_round_trip_union() {
let u = Ty::unions(vec![Ty::int(), Ty::string(), Ty::none()]);
assert_round_trip(u);
}
#[test]
fn test_round_trip_starlark_value() {
assert_round_trip(Ty::starlark_value::<FrozenDict>());
}
#[test]
fn test_round_trip_nested() {
let ty = Ty::dict(
Ty::string(),
Ty::list(Ty::tuple(vec![Ty::int(), Ty::none()])),
);
assert_round_trip(ty);
}
#[test]
fn test_multiple_values_in_one_stream() {
let a = Ty::list(Ty::string());
let b = Ty::dict(Ty::int(), Ty::bool());
let mut ser = TestingSerializer::new();
a.pagable_serialize(&mut ser).unwrap();
b.pagable_serialize(&mut ser).unwrap();
let bytes = ser.finish();
let mut de = TestingDeserializer::new(&bytes);
let restored_a = Ty::pagable_deserialize(&mut de).unwrap();
let restored_b = Ty::pagable_deserialize(&mut de).unwrap();
assert_eq!(restored_a, a);
assert_eq!(restored_b, b);
assert_ne!(restored_a, restored_b);
}
}