use crate::checker::EnvResolve;
use crate::store::Store;
use syntax::ast::{Expression, MatchArm, Pattern, SelectArm, SelectArmPattern, Span};
use syntax::types::{Type, unqualified_name};
use super::super::TaskState;
use crate::validators::temp_producing::is_temp_producing;
impl TaskState<'_> {
pub(super) fn infer_select(
&mut self,
store: &mut Store,
arms: Vec<SelectArm>,
span: Span,
expected_ty: &Type,
) -> Expression {
if arms.is_empty() {
self.sink.push(diagnostics::infer::empty_select(span));
self.unify(store, expected_ty, &Type::unit(), &span);
return Expression::Select {
arms: vec![],
ty: expected_ty.resolve_in(&self.env),
span,
};
}
self.check_multiple_select_receives(&arms);
self.check_duplicate_select_defaults(&arms);
let result_ty = self.new_type_var();
self.unify(store, expected_ty, &result_ty, &span);
let needs_reconciliation = result_ty.resolve_in(&self.env).is_variable();
let mut arm_target_types: Vec<Type> = if needs_reconciliation {
Vec::with_capacity(arms.len())
} else {
Vec::new()
};
let new_arms: Vec<SelectArm> = arms
.into_iter()
.map(|arm| {
self.scopes.push();
let independent_ty;
let arm_target = if needs_reconciliation {
independent_ty = self.new_type_var();
&independent_ty
} else {
&result_ty
};
let new_arm_pattern = match arm.pattern {
SelectArmPattern::Receive {
binding,
receive_expression,
body,
..
} => self.infer_select_receive(
store,
binding,
receive_expression,
body,
arm_target,
),
SelectArmPattern::Send {
send_expression,
body,
} => self.infer_select_send(store, send_expression, body, arm_target),
SelectArmPattern::MatchReceive {
receive_expression,
arms: match_arms,
} => self.infer_select_match_receive(
store,
receive_expression,
match_arms,
arm_target,
),
SelectArmPattern::WildCard { body } => {
self.infer_select_wildcard(store, body, arm_target)
}
};
if needs_reconciliation {
arm_target_types.push(arm_target.clone());
}
self.scopes.pop();
SelectArm {
pattern: new_arm_pattern,
}
})
.collect();
if needs_reconciliation {
self.reconcile_and_unify(store, &result_ty, &arm_target_types, &span);
}
let resolved_result = result_ty.resolve_in(&self.env);
let shorthand_receive_count = new_arms
.iter()
.filter(|arm| matches!(arm.pattern, SelectArmPattern::Receive { .. }))
.count();
let has_default = new_arms
.iter()
.any(|arm| matches!(arm.pattern, SelectArmPattern::WildCard { .. }));
if !expected_ty.is_ignored()
&& !resolved_result.is_unit()
&& !resolved_result.is_variable()
&& shorthand_receive_count == 1
&& !has_default
{
self.sink
.push(diagnostics::infer::non_exhaustive_select_expression(span));
}
Expression::Select {
arms: new_arms,
ty: result_ty,
span,
}
}
fn infer_select_receive(
&mut self,
store: &mut Store,
binding: Box<Pattern>,
receive_expression: Box<Expression>,
body: Box<Expression>,
result_ty: &Type,
) -> SelectArmPattern {
let receive_ty = self.new_type_var();
let new_receive_expression = self.infer_expression(store, *receive_expression, &receive_ty);
self.check_complex_select_expression(&new_receive_expression);
let element_ty = if self.is_channel_receive_call(&new_receive_expression) {
receive_ty.clone()
} else {
self.sink.push(diagnostics::infer::expected_channel_receive(
&receive_ty,
new_receive_expression.get_span(),
));
Type::Error
};
let inner_binding: &Pattern = match binding.as_ref() {
Pattern::AsBinding { pattern, span, .. } => {
let is_some = matches!(pattern.as_ref(), Pattern::EnumVariant { identifier, .. }
if unqualified_name(identifier) == "Some");
if is_some {
self.sink
.push(diagnostics::infer::select_some_as_binding_not_supported(
*span,
));
} else {
self.sink
.push(diagnostics::infer::as_binding_in_irrefutable_context(*span));
}
pattern.as_ref()
}
p => p,
};
if matches!(inner_binding, Pattern::Identifier { .. }) {
self.sink
.push(diagnostics::infer::bare_identifier_in_select_receive(
&binding.get_span(),
));
}
if let Pattern::EnumVariant {
identifier, fields, ..
} = inner_binding
{
let variant_name = unqualified_name(identifier);
if variant_name == "None" {
self.sink
.push(diagnostics::infer::none_pattern_in_select_receive(
&binding.get_span(),
));
}
if variant_name == "Some"
&& fields.len() == 1
&& !TaskState::is_irrefutable_select_pattern(&fields[0])
{
self.sink
.push(diagnostics::infer::select_receive_refutable_pattern(
fields[0].get_span(),
));
}
}
let (new_binding, typed_pattern) = self.infer_pattern(
store,
*binding,
element_ty.clone(),
syntax::ast::BindingKind::Let { mutable: false },
);
let saved_in_match_arm = self.scopes.set_in_match_arm(true);
self.scopes.set_in_subexpression(false);
let new_body = self.infer_expression(store, *body, result_ty);
self.scopes.set_in_match_arm(saved_in_match_arm);
SelectArmPattern::Receive {
binding: Box::new(new_binding),
typed_pattern: Some(typed_pattern),
receive_expression: Box::new(new_receive_expression),
body: Box::new(new_body),
}
}
fn infer_select_send(
&mut self,
store: &mut Store,
send_expression: Box<Expression>,
body: Box<Expression>,
result_ty: &Type,
) -> SelectArmPattern {
let send_ty = self.new_type_var();
let new_send_expression = self.infer_expression(store, *send_expression, &send_ty);
self.check_complex_select_expression(&new_send_expression);
if !self.is_channel_send_call(&new_send_expression)
&& !self.is_channel_receive_call(&new_send_expression)
{
self.sink.push(diagnostics::infer::expected_channel_send(
new_send_expression.get_span(),
));
}
let saved_in_match_arm = self.scopes.set_in_match_arm(true);
self.scopes.set_in_subexpression(false);
let new_body = self.infer_expression(store, *body, result_ty);
self.scopes.set_in_match_arm(saved_in_match_arm);
SelectArmPattern::Send {
send_expression: Box::new(new_send_expression),
body: Box::new(new_body),
}
}
fn infer_select_match_receive(
&mut self,
store: &mut Store,
receive_expression: Box<Expression>,
match_arms: Vec<MatchArm>,
result_ty: &Type,
) -> SelectArmPattern {
let receive_ty = self.new_type_var();
let new_receive_expression = self.infer_expression(store, *receive_expression, &receive_ty);
self.check_complex_select_expression(&new_receive_expression);
if !self.is_channel_receive_call(&new_receive_expression) {
self.sink.push(diagnostics::infer::expected_channel_receive(
&receive_ty,
new_receive_expression.get_span(),
));
}
self.check_select_match_arms(&match_arms, new_receive_expression.get_span());
let pattern_ty = receive_ty.resolve_in(&self.env);
let needs_reconciliation = result_ty.resolve_in(&self.env).is_variable();
let mut arm_expression_types: Vec<Type> = if needs_reconciliation {
Vec::with_capacity(match_arms.len())
} else {
Vec::new()
};
let new_match_arms: Vec<MatchArm> = match_arms
.into_iter()
.map(|match_arm| {
self.scopes.push();
let (new_pattern, typed_pattern) = self.infer_pattern(
store,
match_arm.pattern,
pattern_ty.clone(),
syntax::ast::BindingKind::MatchArm,
);
let bool_ty = self.type_bool();
let new_guard = match_arm.guard.map(|guard| {
let guard_expression = self.infer_expression(store, *guard, &bool_ty);
Box::new(guard_expression)
});
let independent_ty;
let arm_expected = if needs_reconciliation {
independent_ty = self.new_type_var();
&independent_ty
} else {
result_ty
};
let saved_in_match_arm = self.scopes.set_in_match_arm(true);
self.scopes.set_in_subexpression(false);
let new_expression =
self.infer_expression(store, *match_arm.expression, arm_expected);
self.scopes.set_in_match_arm(saved_in_match_arm);
if needs_reconciliation {
arm_expression_types.push(arm_expected.clone());
}
self.scopes.pop();
MatchArm {
pattern: new_pattern,
guard: new_guard,
typed_pattern: Some(typed_pattern),
expression: Box::new(new_expression),
}
})
.collect();
if needs_reconciliation {
let span = new_receive_expression.get_span();
self.reconcile_and_unify(store, result_ty, &arm_expression_types, &span);
}
SelectArmPattern::MatchReceive {
receive_expression: Box::new(new_receive_expression),
arms: new_match_arms,
}
}
fn infer_select_wildcard(
&mut self,
store: &mut Store,
body: Box<Expression>,
result_ty: &Type,
) -> SelectArmPattern {
let saved_in_match_arm = self.scopes.set_in_match_arm(true);
self.scopes.set_in_subexpression(false);
let new_body = self.infer_expression(store, *body, result_ty);
self.scopes.set_in_match_arm(saved_in_match_arm);
SelectArmPattern::WildCard {
body: Box::new(new_body),
}
}
pub(crate) fn is_channel_receive_call(&self, expression: &Expression) -> bool {
if let Expression::Call {
expression, args, ..
} = expression
{
if args.is_empty()
&& let Expression::DotAccess {
member,
expression: receiver,
..
} = expression.as_ref()
&& member == "receive"
{
return self.is_channel_type(&receiver.get_type());
}
if args.len() == 1
&& let Expression::Identifier { value, .. } = expression.as_ref()
&& value.ends_with(".receive")
&& Self::is_ufcs_channel_prefix(value)
&& self.is_channel_type(&args[0].get_type())
{
return true;
}
}
false
}
pub(crate) fn is_channel_send_call(&self, expression: &Expression) -> bool {
if let Expression::Call {
expression, args, ..
} = expression
{
if let Expression::DotAccess {
member,
expression: receiver,
..
} = expression.as_ref()
&& member == "send"
&& args.len() == 1
&& self.is_channel_type(&receiver.get_type())
{
return true;
}
if args.len() == 2
&& let Expression::Identifier { value, .. } = expression.as_ref()
&& value.ends_with(".send")
&& Self::is_ufcs_channel_prefix(value)
&& self.is_channel_type(&args[0].get_type())
{
return true;
}
}
false
}
fn is_ufcs_channel_prefix(identifier: &str) -> bool {
if let Some((prefix, _method)) = identifier.rsplit_once('.') {
let base = prefix.rsplit_once('.').map(|(_, b)| b).unwrap_or(prefix);
matches!(base, "Channel" | "Sender" | "Receiver")
} else {
false
}
}
fn is_channel_type(&self, ty: &Type) -> bool {
let resolved = ty.resolve_in(&self.env).strip_refs();
matches!(resolved.get_name(), Some("Channel" | "Sender" | "Receiver"))
}
fn extract_channel_expression(expression: &Expression) -> Option<&Expression> {
let Expression::Call {
expression, args, ..
} = expression
else {
return None;
};
if let Expression::DotAccess {
expression: channel,
member,
..
} = expression.as_ref()
&& (member == "send" || member == "receive")
{
return Some(channel);
}
if let Expression::Identifier { value, .. } = expression.as_ref()
&& (value.ends_with(".send") || value.ends_with(".receive"))
&& !args.is_empty()
{
return Some(&args[0]);
}
None
}
fn extract_send_value(expression: &Expression) -> Option<&Expression> {
let Expression::Call {
expression, args, ..
} = expression
else {
return None;
};
if let Expression::DotAccess { member, .. } = expression.as_ref()
&& member == "send"
&& args.len() == 1
{
return Some(&args[0]);
}
if let Expression::Identifier { value, .. } = expression.as_ref()
&& value.ends_with(".send")
&& args.len() == 2
{
return Some(&args[1]);
}
None
}
fn check_complex_select_expression(&mut self, expression: &Expression) {
if let Some(channel) = Self::extract_channel_expression(expression)
&& is_temp_producing(channel)
{
self.sink
.push(diagnostics::infer::complex_select_expression(
channel.get_span(),
));
}
if let Some(value) = Self::extract_send_value(expression)
&& is_temp_producing(value)
{
self.sink
.push(diagnostics::infer::complex_select_expression(
value.get_span(),
));
}
}
}