use bynk_syntax::ast::{BaseType, CommonsItem, FnName, SourceUnit, TypeBody};
use std::path::Path;
use crate::completion::{BUILTIN_STATICS, for_each_unit};
use crate::symbols::type_ref_str;
#[derive(Debug, PartialEq, Eq)]
pub struct CallContext {
pub callee: String,
pub active_param: usize,
pub open_paren: usize,
}
pub fn call_context(text: &str, offset: usize) -> Option<CallContext> {
let prefix = text.get(..offset)?;
let open = innermost_unclosed_paren(prefix)?;
let callee = callee_before(&prefix[..open])?;
let active = top_level_commas(&prefix[open + 1..]);
Some(CallContext {
callee,
active_param: active,
open_paren: open,
})
}
pub fn value_receiver_method(callee: &str) -> Option<(&str, &str)> {
let (recv, method) = callee.rsplit_once('.')?;
let first = recv.chars().next()?;
if (first.is_ascii_lowercase() || first == '_') && !recv.contains('.') {
Some((recv, method))
} else {
None
}
}
pub fn value_receiver_rewrite(
text: &str,
callee: &str,
open_paren: usize,
cursor: usize,
) -> Option<(String, usize)> {
let (recv, _) = value_receiver_method(callee)?;
let callee_start = open_paren.checked_sub(callee.len())?;
let dot = callee_start + recv.len();
let rewritten = format!("{}{}", &text[..dot], &text[cursor..]);
Some((rewritten, dot.saturating_sub(1)))
}
pub fn kernel_method_signature(ty: &bynk_check::checker::Ty, method: &str) -> Option<String> {
bynk_check::kernel_methods::methods_for(ty)
.iter()
.find(|m| m.name == method)
.map(|m| m.signature.to_string())
}
fn innermost_unclosed_paren(prefix: &str) -> Option<usize> {
let b = prefix.as_bytes();
let mut depth = 0i32;
for i in (0..b.len()).rev() {
match b[i] {
b')' | b']' | b'}' => depth += 1,
b'(' => {
if depth == 0 {
return Some(i);
}
depth -= 1;
}
b'[' | b'{' => {
if depth == 0 {
return None;
}
depth -= 1;
}
_ => {}
}
}
None
}
fn top_level_commas(s: &str) -> usize {
let mut depth = 0i32;
let mut n = 0;
for c in s.chars() {
match c {
'(' | '[' | '{' => depth += 1,
')' | ']' | '}' => depth -= 1,
',' if depth == 0 => n += 1,
_ => {}
}
}
n
}
fn callee_before(s: &str) -> Option<String> {
let s = s.trim_end();
let start = s
.rfind(|c: char| !(c.is_alphanumeric() || c == '_' || c == '.'))
.map_or(0, |i| i + 1);
let callee = &s[start..];
if callee.is_empty() || callee.starts_with('.') || callee.ends_with('.') {
return None;
}
Some(callee.to_string())
}
pub fn resolve_label(callee: &str, doc_text: &str, src_root: Option<&Path>) -> Option<String> {
if let Some((recv, member)) = callee.rsplit_once('.') {
if let Some((_, statics)) = BUILTIN_STATICS.iter().find(|(n, _)| *n == recv)
&& let Some((_, sig)) = statics.iter().find(|(n, _)| *n == member)
{
return Some((*sig).to_string());
}
return resolve_qualified(recv, member, doc_text, src_root);
}
match callee {
"Ok" => return Some("Ok(value: T) -> Result[T, E]".to_string()),
"Err" => return Some("Err(error: E) -> Result[T, E]".to_string()),
"Some" => return Some("Some(value: T) -> Option[T]".to_string()),
_ => {}
}
let mut found = None;
for_each_unit(doc_text, src_root, |unit| {
if found.is_some() {
return;
}
let items = match unit {
SourceUnit::Commons(c) => &c.items,
SourceUnit::Context(c) => &c.items,
SourceUnit::Adapter(a) => &a.items,
_ => return,
};
for item in items {
if let CommonsItem::Fn(f) = item
&& let FnName::Free(id) = &f.name
&& id.name == callee
{
let params: Vec<String> = f
.params
.iter()
.map(|p| format!("{}: {}", p.name.name, type_ref_str(&p.type_ref)))
.collect();
found = Some(format!(
"{callee}({}) -> {}",
params.join(", "),
type_ref_str(&f.return_type)
));
return;
}
}
});
found
}
fn resolve_qualified(
recv: &str,
member: &str,
doc_text: &str,
src_root: Option<&Path>,
) -> Option<String> {
let mut out = None;
for_each_unit(doc_text, src_root, |unit| {
if out.is_some() {
return;
}
let items = match unit {
SourceUnit::Commons(c) => &c.items,
SourceUnit::Context(c) => &c.items,
SourceUnit::Adapter(a) => &a.items,
_ => return,
};
for item in items {
match item {
CommonsItem::Type(t)
if t.name.name == recv && (member == "of" || member == "unsafe") =>
{
let base = match &t.body {
TypeBody::Refined { base, .. } | TypeBody::Opaque { base, .. } => {
base_name(*base)
}
_ => return,
};
out = Some(if member == "of" {
format!("of(value: {base}) -> Result[{recv}, ValidationError]")
} else {
format!("unsafe(value: {base}) -> {recv}")
});
return;
}
CommonsItem::Capability(c) if c.name.name == recv => {
if let Some(op) = c.ops.iter().find(|o| o.name.name == member) {
let params: Vec<String> = op
.params
.iter()
.map(|p| format!("{}: {}", p.name.name, type_ref_str(&p.type_ref)))
.collect();
out = Some(format!(
"{member}({}) -> {}",
params.join(", "),
type_ref_str(&op.return_type)
));
return;
}
}
_ => {}
}
}
});
out
}
fn base_name(b: BaseType) -> &'static str {
match b {
BaseType::Int => "Int",
BaseType::Float => "Float",
BaseType::String => "String",
BaseType::Bool => "Bool",
}
}
pub fn param_ranges(label: &str) -> Vec<(usize, usize)> {
let Some(open) = label.find('(') else {
return Vec::new();
};
let mut ranges = Vec::new();
let mut depth = 0i32;
let mut seg_start = open + 1;
let bytes = label.as_bytes();
let mut i = open;
while i < bytes.len() {
match bytes[i] {
b'(' | b'[' | b'{' => depth += 1,
b')' | b']' | b'}' => {
depth -= 1;
if depth == 0 {
push_trimmed(label, seg_start, i, &mut ranges);
break;
}
}
b',' if depth == 1 => {
push_trimmed(label, seg_start, i, &mut ranges);
seg_start = i + 1;
}
_ => {}
}
i += 1;
}
ranges
}
fn push_trimmed(label: &str, start: usize, end: usize, out: &mut Vec<(usize, usize)>) {
let seg = &label[start..end];
let trimmed = seg.trim();
if trimmed.is_empty() {
return;
}
let s = start + (seg.len() - seg.trim_start().len());
out.push((s, s + trimmed.len()));
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn call_context_finds_callee_and_active_param() {
let t = " let x = f(a, b";
let ctx = call_context(t, t.len()).unwrap();
assert_eq!(ctx.callee, "f");
assert_eq!(ctx.active_param, 1); }
#[test]
fn innermost_call_wins_in_nested_calls() {
let t = " outer(g(x";
let ctx = call_context(t, t.len()).unwrap();
assert_eq!(ctx.callee, "g");
assert_eq!(ctx.active_param, 0);
let t2 = " outer(g(x), ";
let ctx2 = call_context(t2, t2.len()).unwrap();
assert_eq!(ctx2.callee, "outer");
assert_eq!(ctx2.active_param, 1);
}
#[test]
fn commas_inside_nested_brackets_dont_count() {
let t = " f(g(a, b), ";
assert_eq!(call_context(t, t.len()).unwrap().active_param, 1);
}
#[test]
fn qualified_callee_and_no_call_context() {
let t = " Clock.now(";
assert_eq!(call_context(t, t.len()).unwrap().callee, "Clock.now");
assert!(call_context(" let x = 1", 11).is_none()); assert!(call_context(" xs[", 4).is_none()); }
#[test]
fn builtin_static_and_constructor_labels() {
assert_eq!(
resolve_label("Int.parse", "context a.b\n", None).as_deref(),
Some("parse(s: String) -> Option[Int]")
);
assert!(
resolve_label("Ok", "context a.b\n", None)
.unwrap()
.starts_with("Ok(value")
);
}
#[test]
fn free_fn_and_capability_op_and_refined_labels() {
let doc = "commons m {\n fn add(a: Int, b: Int) -> Int { a }\n}\n";
assert_eq!(
resolve_label("add", doc, None).as_deref(),
Some("add(a: Int, b: Int) -> Int")
);
let cap = "context a.b\n capability Timer { fn after(label: String) -> Effect[Int] }\n";
assert_eq!(
resolve_label("Timer.after", cap, None).as_deref(),
Some("after(label: String) -> Effect[Int]")
);
let refined = "commons m {\n type Email = String where NonEmpty\n}\n";
assert_eq!(
resolve_label("Email.of", refined, None).as_deref(),
Some("of(value: String) -> Result[Email, ValidationError]")
);
}
#[test]
fn value_receiver_callee_detection_and_rewrite() {
assert_eq!(value_receiver_method("xs.fold"), Some(("xs", "fold")));
assert_eq!(value_receiver_method("Int.parse"), None); assert_eq!(value_receiver_method("a.b.fold"), None); assert_eq!(value_receiver_method("bar"), None);
let text = " let r = xs.fold(0, ";
let open = text.find('(').unwrap();
let (rw, off) = value_receiver_rewrite(text, "xs.fold", open, text.len()).unwrap();
assert_eq!(rw, " let r = xs", "the `.fold(0, ` is dropped");
assert_eq!(&text[off..=off], "s", "offset lands inside `xs`");
}
#[test]
fn kernel_method_signature_lookup() {
use bynk_check::checker::Ty;
use bynk_syntax::ast::BaseType;
let list = Ty::List(Box::new(Ty::Base(BaseType::Int)));
assert!(
kernel_method_signature(&list, "fold")
.unwrap()
.starts_with("fold(")
);
let string = Ty::Base(BaseType::String);
assert!(
kernel_method_signature(&string, "split")
.unwrap()
.starts_with("split(")
);
assert!(kernel_method_signature(&string, "nope").is_none());
}
#[test]
fn param_ranges_split_top_level_only() {
let label = "fold(init: U, step: (U, T) -> U) -> U";
let r = param_ranges(label);
assert_eq!(r.len(), 2, "two params, not split inside (U, T): {r:?}");
assert_eq!(&label[r[0].0..r[0].1], "init: U");
assert_eq!(&label[r[1].0..r[1].1], "step: (U, T) -> U");
}
}